436 lines
14 KiB
Python
436 lines
14 KiB
Python
import os
|
|
import json
|
|
import threading
|
|
import socketserver
|
|
import base64
|
|
import time
|
|
from datetime import datetime, timezone
|
|
from Crypto.PublicKey import RSA
|
|
from Crypto.Cipher import PKCS1_OAEP, AES
|
|
from Crypto.Random import get_random_bytes
|
|
from Crypto.Util.number import GCD
|
|
from phe import paillier
|
|
|
|
DATA_DIR = "server_data"
|
|
DOCTORS_FILE = os.path.join(DATA_DIR, "doctors.json")
|
|
EXPENSES_FILE = os.path.join(DATA_DIR, "expenses.json")
|
|
REPORTS_FILE = os.path.join(DATA_DIR, "reports.json")
|
|
CONF_FILE = os.path.join(DATA_DIR, "config.json")
|
|
RSA_PRIV_FILE = os.path.join(DATA_DIR, "server_rsa_priv.pem")
|
|
RSA_PUB_FILE = os.path.join(DATA_DIR, "server_rsa_pub.pem")
|
|
PORT = 5000
|
|
|
|
lock = threading.Lock()
|
|
|
|
def ensure_dirs():
|
|
os.makedirs(DATA_DIR, exist_ok=True)
|
|
|
|
def read_json(path, default):
|
|
if not os.path.exists(path):
|
|
return default
|
|
with open(path, "r") as f:
|
|
return json.load(f)
|
|
|
|
def write_json(path, obj):
|
|
tmp = path + ".tmp"
|
|
with open(tmp, "w") as f:
|
|
json.dump(obj, f, indent=2)
|
|
os.replace(tmp, path)
|
|
|
|
def load_or_create_rsa():
|
|
if not os.path.exists(RSA_PRIV_FILE):
|
|
key = RSA.generate(2048)
|
|
with open(RSA_PRIV_FILE, "wb") as f:
|
|
f.write(key.export_key())
|
|
with open(RSA_PUB_FILE, "wb") as f:
|
|
f.write(key.public_key().export_key())
|
|
with open(RSA_PRIV_FILE, "rb") as f:
|
|
priv = RSA.import_key(f.read())
|
|
with open(RSA_PUB_FILE, "rb") as f:
|
|
pub = RSA.import_key(f.read())
|
|
return priv, pub
|
|
|
|
def load_or_create_paillier():
|
|
conf = read_json(CONF_FILE, {})
|
|
if "paillier" not in conf:
|
|
pubkey, privkey = paillier.generate_paillier_keypair()
|
|
conf["paillier"] = {
|
|
"n": str(pubkey.n),
|
|
"p": str(privkey.p),
|
|
"q": str(privkey.q),
|
|
}
|
|
write_json(CONF_FILE, conf)
|
|
conf = read_json(CONF_FILE, {})
|
|
n = int(conf["paillier"]["n"])
|
|
p = int(conf["paillier"]["p"])
|
|
q = int(conf["paillier"]["q"])
|
|
pubkey = paillier.PaillierPublicKey(n)
|
|
privkey = paillier.PaillierPrivateKey(pubkey, p, q)
|
|
return pubkey, privkey
|
|
|
|
def load_or_create_config_rsa_homomorphic_base(rsa_pub):
|
|
conf = read_json(CONF_FILE, {})
|
|
n = rsa_pub.n
|
|
if "rsa_homomorphic" not in conf:
|
|
# pick base g coprime to n
|
|
import random
|
|
while True:
|
|
g = random.randrange(2, n - 1)
|
|
if GCD(g, n) == 1:
|
|
break
|
|
conf["rsa_homomorphic"] = {
|
|
"g": str(g)
|
|
}
|
|
write_json(CONF_FILE, conf)
|
|
conf = read_json(CONF_FILE, {})
|
|
g = int(conf["rsa_homomorphic"]["g"])
|
|
return g
|
|
|
|
def b64e(b: bytes) -> str:
|
|
return base64.b64encode(b).decode()
|
|
|
|
def b64d(s: str) -> bytes:
|
|
return base64.b64decode(s.encode())
|
|
|
|
def init_storage():
|
|
ensure_dirs()
|
|
priv, pub = load_or_create_rsa()
|
|
_ = load_or_create_paillier()
|
|
if not os.path.exists(DOCTORS_FILE):
|
|
write_json(DOCTORS_FILE, {})
|
|
if not os.path.exists(EXPENSES_FILE):
|
|
write_json(EXPENSES_FILE, [])
|
|
if not os.path.exists(REPORTS_FILE):
|
|
write_json(REPORTS_FILE, [])
|
|
return priv, pub
|
|
|
|
RSA_PRIV, RSA_PUB = init_storage()
|
|
PAI_PUB, PAI_PRIV = load_or_create_paillier()
|
|
RSA_HOMO_G = load_or_create_config_rsa_homomorphic_base(RSA_PUB)
|
|
|
|
def get_public_info():
|
|
return {
|
|
"rsa_pub_pem_b64": b64e(RSA_PUB.export_key()),
|
|
"rsa_n": str(RSA_PUB.n),
|
|
"rsa_e": str(RSA_PUB.e),
|
|
"paillier_n": str(PAI_PUB.n),
|
|
"rsa_homo_g": str(RSA_HOMO_G),
|
|
}
|
|
|
|
def handle_register_doctor(body):
|
|
# body: {doctor_id, department_plain, dept_enc: {ciphertext, exponent}, elgamal_pub: {p,g,y}}
|
|
doc_id = body.get("doctor_id","").strip()
|
|
dept_plain = body.get("department_plain","").strip()
|
|
dept_enc = body.get("dept_enc")
|
|
elgamal_pub = body.get("elgamal_pub")
|
|
if not doc_id or not doc_id.isalnum():
|
|
return {"status":"error","error":"invalid doctor_id"}
|
|
if not dept_plain:
|
|
return {"status":"error","error":"invalid department"}
|
|
if not dept_enc or "ciphertext" not in dept_enc or "exponent" not in dept_enc:
|
|
return {"status":"error","error":"invalid dept_enc"}
|
|
if not elgamal_pub or not all(k in elgamal_pub for k in ["p","g","y"]):
|
|
return {"status":"error","error":"missing elgamal_pub"}
|
|
|
|
with lock:
|
|
doctors = read_json(DOCTORS_FILE, {})
|
|
doctors[doc_id] = {
|
|
"department_plain": dept_plain,
|
|
"dept_enc": {
|
|
"ciphertext": str(int(dept_enc["ciphertext"])),
|
|
"exponent": int(dept_enc["exponent"])
|
|
},
|
|
"elgamal_pub": {
|
|
"p": str(int(elgamal_pub["p"])),
|
|
"g": str(int(elgamal_pub["g"])),
|
|
"y": str(int(elgamal_pub["y"]))
|
|
}
|
|
}
|
|
write_json(DOCTORS_FILE, doctors)
|
|
print(f"[server] registered doctor {doc_id} dept='{dept_plain}' (stored encrypted and plaintext)")
|
|
return {"status":"ok"}
|
|
|
|
def handle_upload_report(body):
|
|
# body: {doctor_id, filename, timestamp, md5_hex, sig: {r,s}, aes: {key_rsa_oaep_b64, nonce_b64, tag_b64, ct_b64}}
|
|
doc_id = body.get("doctor_id","").strip()
|
|
filename = os.path.basename(body.get("filename","").strip())
|
|
timestamp = body.get("timestamp","").strip()
|
|
md5_hex = body.get("md5_hex","").strip()
|
|
sig = body.get("sig")
|
|
aes = body.get("aes")
|
|
if not doc_id or not filename or not timestamp or not md5_hex or not sig or not aes:
|
|
return {"status":"error","error":"missing fields"}
|
|
|
|
with lock:
|
|
doctors = read_json(DOCTORS_FILE, {})
|
|
if doc_id not in doctors:
|
|
return {"status":"error","error":"unknown doctor_id"}
|
|
|
|
# decrypt AES key
|
|
try:
|
|
rsa_cipher = PKCS1_OAEP.new(RSA_PRIV)
|
|
aes_key = rsa_cipher.decrypt(b64d(aes["key_rsa_oaep_b64"]))
|
|
nonce = b64d(aes["nonce_b64"])
|
|
tag = b64d(aes["tag_b64"])
|
|
ct = b64d(aes["ct_b64"])
|
|
aes_cipher = AES.new(aes_key, AES.MODE_EAX, nonce=nonce)
|
|
report_bytes = aes_cipher.decrypt_and_verify(ct, tag)
|
|
except Exception as e:
|
|
return {"status":"error","error":f"aes/rsa decrypt failed: {e}"}
|
|
|
|
# verify MD5
|
|
import hashlib
|
|
md5_check = hashlib.md5(report_bytes).hexdigest()
|
|
if md5_check != md5_hex:
|
|
print("[server] md5 mismatch")
|
|
# store file
|
|
outdir = os.path.join(DATA_DIR, "reports")
|
|
os.makedirs(outdir, exist_ok=True)
|
|
savepath = os.path.join(outdir, f"{doc_id}_{int(time.time())}_{filename}")
|
|
with open(savepath, "wb") as f:
|
|
f.write(report_bytes)
|
|
|
|
# store record
|
|
rec = {
|
|
"doctor_id": doc_id,
|
|
"filename": filename,
|
|
"saved_path": savepath,
|
|
"timestamp": timestamp,
|
|
"md5_hex": md5_hex,
|
|
"sig": {"r": str(int(sig["r"])), "s": str(int(sig["s"]))}
|
|
}
|
|
with lock:
|
|
records = read_json(REPORTS_FILE, [])
|
|
records.append(rec)
|
|
write_json(REPORTS_FILE, records)
|
|
print(f"[server] report uploaded by {doc_id}, stored {savepath}")
|
|
return {"status":"ok"}
|
|
|
|
def handle_submit_expense(body):
|
|
# body: {doctor_id, amount_ciphertext}
|
|
doc_id = body.get("doctor_id","").strip()
|
|
c = body.get("amount_ciphertext")
|
|
if not doc_id or not doc_id.isalnum():
|
|
return {"status":"error","error":"invalid doctor_id"}
|
|
try:
|
|
c_int = int(c)
|
|
except:
|
|
return {"status":"error","error":"invalid ciphertext"}
|
|
with lock:
|
|
doctors = read_json(DOCTORS_FILE, {})
|
|
if doc_id not in doctors:
|
|
return {"status":"error","error":"unknown doctor_id"}
|
|
|
|
with lock:
|
|
expenses = read_json(EXPENSES_FILE, [])
|
|
expenses.append({"doctor_id": doc_id, "ciphertext": str(c_int)})
|
|
write_json(EXPENSES_FILE, expenses)
|
|
print(f"[server] expense ciphertext stored for {doc_id}")
|
|
return {"status":"ok"}
|
|
|
|
class RequestHandler(socketserver.StreamRequestHandler):
|
|
def handle(self):
|
|
try:
|
|
data = self.rfile.readline()
|
|
if not data:
|
|
return
|
|
req = json.loads(data.decode())
|
|
action = req.get("action")
|
|
role = req.get("role", "")
|
|
body = req.get("body", {})
|
|
if action == "get_public_info":
|
|
resp = {"status":"ok","data": get_public_info()}
|
|
elif action == "register_doctor":
|
|
if role != "doctor":
|
|
resp = {"status":"error","error":"unauthorized"}
|
|
else:
|
|
resp = handle_register_doctor(body)
|
|
elif action == "upload_report":
|
|
if role != "doctor":
|
|
resp = {"status":"error","error":"unauthorized"}
|
|
else:
|
|
resp = handle_upload_report(body)
|
|
elif action == "submit_expense":
|
|
if role != "doctor":
|
|
resp = {"status":"error","error":"unauthorized"}
|
|
else:
|
|
resp = handle_submit_expense(body)
|
|
else:
|
|
resp = {"status":"error","error":"unknown action"}
|
|
except Exception as e:
|
|
resp = {"status":"error","error":str(e)}
|
|
self.wfile.write((json.dumps(resp)+"\n").encode())
|
|
|
|
def start_server():
|
|
server = socketserver.ThreadingTCPServer(("127.0.0.1", PORT), RequestHandler)
|
|
t = threading.Thread(target=server.serve_forever, daemon=True)
|
|
t.start()
|
|
print(f"[server] listening on 127.0.0.1:{PORT}")
|
|
return server
|
|
|
|
# Auditor utilities
|
|
|
|
def load_doctors():
|
|
return read_json(DOCTORS_FILE, {})
|
|
|
|
def load_expenses():
|
|
return read_json(EXPENSES_FILE, [])
|
|
|
|
def load_reports():
|
|
return read_json(REPORTS_FILE, [])
|
|
|
|
def audit_list_doctors():
|
|
docs = load_doctors()
|
|
print("Doctors:")
|
|
for did, info in docs.items():
|
|
enc = info["dept_enc"]
|
|
print(f"- {did} dept_plain='{info['department_plain']}' enc_ciphertext={enc['ciphertext']} exponent={enc['exponent']}")
|
|
|
|
def audit_keyword_search():
|
|
docs = load_doctors()
|
|
if not docs:
|
|
print("no doctors")
|
|
return
|
|
q = input("Enter department keyword to search: ").strip()
|
|
if not q:
|
|
print("empty")
|
|
return
|
|
# hash to int
|
|
import hashlib
|
|
h = int.from_bytes(hashlib.sha256(q.encode()).digest(), "big")
|
|
pub = PAI_PUB
|
|
priv = PAI_PRIV
|
|
enc_q = pub.encrypt(h)
|
|
print("Matching doctors (using Paillier equality on hashed dept):")
|
|
for did, info in docs.items():
|
|
enc = info["dept_enc"]
|
|
c = int(enc["ciphertext"])
|
|
exp = int(enc["exponent"])
|
|
enc_doc = paillier.EncryptedNumber(pub, c, exp)
|
|
diff = enc_doc - enc_q
|
|
dec = priv.decrypt(diff)
|
|
match = (dec == 0)
|
|
print(f" {did}: dept_plain='{info['department_plain']}' enc_ciphertext={c} match={match}")
|
|
|
|
def rsa_homo_decrypt_sum(c_prod_int):
|
|
n = RSA_PRIV.n
|
|
d = RSA_PRIV.d
|
|
g = RSA_HOMO_G
|
|
# decrypt to get g^sum mod n
|
|
m = pow(int(c_prod_int), d, n)
|
|
# brute force discrete log for small sums
|
|
max_iter = 500000
|
|
acc = 1
|
|
for k in range(0, max_iter+1):
|
|
if acc == m:
|
|
return k
|
|
acc = (acc * g) % n
|
|
return None
|
|
|
|
def audit_sum_expenses():
|
|
exps = load_expenses()
|
|
if not exps:
|
|
print("no expenses")
|
|
return
|
|
# sum all
|
|
n = RSA_PUB.n
|
|
c_prod = 1
|
|
for e in exps:
|
|
c_prod = (c_prod * int(e["ciphertext"])) % n
|
|
print(f"Product ciphertext (represents sum under RSA-in-exponent): {c_prod}")
|
|
s = rsa_homo_decrypt_sum(c_prod)
|
|
if s is None:
|
|
print("sum decryption failed (exceeded search bound)")
|
|
else:
|
|
print(f"Decrypted sum of expenses = {s}")
|
|
# by doctor
|
|
docs = load_doctors()
|
|
if docs:
|
|
print("Per-doctor sums:")
|
|
for did in docs.keys():
|
|
c_prod_d = 1
|
|
count = 0
|
|
for e in exps:
|
|
if e["doctor_id"] == did:
|
|
c_prod_d = (c_prod_d * int(e["ciphertext"])) % n
|
|
count += 1
|
|
if count == 0:
|
|
continue
|
|
s_d = rsa_homo_decrypt_sum(c_prod_d)
|
|
print(f" {did}: entries={count} product_ct={c_prod_d} sum={s_d}")
|
|
|
|
def elgamal_verify(p, g, y, H_int, r, s):
|
|
# verify: g^H ≡ y^r * r^s (mod p)
|
|
return pow(g, H_int, p) == (pow(y, r, p) * pow(r, s, p)) % p
|
|
|
|
def audit_verify_reports():
|
|
records = load_reports()
|
|
if not records:
|
|
print("no reports")
|
|
return
|
|
doctors = load_doctors()
|
|
for rec in records:
|
|
did = rec["doctor_id"]
|
|
docinfo = doctors.get(did)
|
|
ok_sig = False
|
|
ok_ts = False
|
|
if docinfo:
|
|
p = int(docinfo["elgamal_pub"]["p"])
|
|
g = int(docinfo["elgamal_pub"]["g"])
|
|
y = int(docinfo["elgamal_pub"]["y"])
|
|
r = int(rec["sig"]["r"])
|
|
s = int(rec["sig"]["s"])
|
|
try:
|
|
with open(rec["saved_path"], "rb") as f:
|
|
report_bytes = f.read()
|
|
import hashlib
|
|
H = int.from_bytes(hashlib.md5(report_bytes + rec["timestamp"].encode()).digest(), "big") % (p - 1)
|
|
ok_sig = elgamal_verify(p, g, y, H, r, s)
|
|
except Exception as e:
|
|
ok_sig = False
|
|
# timestamp check
|
|
try:
|
|
ts = datetime.fromisoformat(rec["timestamp"])
|
|
except:
|
|
try:
|
|
ts = datetime.strptime(rec["timestamp"], "%Y-%m-%dT%H:%M:%S.%f")
|
|
except:
|
|
ts = None
|
|
if ts:
|
|
now = datetime.utcnow().replace(tzinfo=None)
|
|
delta = (now - ts).total_seconds()
|
|
# simple rule: not in the future by more than 5 min
|
|
ok_ts = (delta >= -300)
|
|
print(f"- report by {did} file={os.path.basename(rec['saved_path'])} sig_ok={ok_sig} ts_ok={ok_ts} ts={rec['timestamp']} md5={rec['md5_hex']}")
|
|
|
|
def auditor_menu():
|
|
while True:
|
|
print("\n[Auditor Menu]")
|
|
print("1) List doctors (show encrypted and plaintext dept)")
|
|
print("2) Keyword search doctors by dept (Paillier)")
|
|
print("3) Sum expenses (RSA-in-exponent demo)")
|
|
print("4) Verify reports and timestamps")
|
|
print("5) Show server public info")
|
|
print("0) Exit")
|
|
ch = input("Select: ").strip()
|
|
if ch == "1":
|
|
audit_list_doctors()
|
|
elif ch == "2":
|
|
audit_keyword_search()
|
|
elif ch == "3":
|
|
audit_sum_expenses()
|
|
elif ch == "4":
|
|
audit_verify_reports()
|
|
elif ch == "5":
|
|
info = get_public_info()
|
|
print(json.dumps(info, indent=2))
|
|
elif ch == "0":
|
|
print("bye")
|
|
break
|
|
else:
|
|
print("invalid")
|
|
|
|
if __name__ == "__main__":
|
|
start_server()
|
|
auditor_menu()
|
|
|