436 lines
		
	
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			436 lines
		
	
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import os
 | |
| import json
 | |
| import threading
 | |
| import socketserver
 | |
| import base64
 | |
| import time
 | |
| from datetime import datetime, timezone
 | |
| from Crypto.PublicKey import RSA
 | |
| from Crypto.Cipher import PKCS1_OAEP, AES
 | |
| from Crypto.Random import get_random_bytes
 | |
| from Crypto.Util.number import GCD
 | |
| from phe import paillier
 | |
| 
 | |
# TCP port the request server listens on (bound to 127.0.0.1 by start_server()).
PORT = 5000

# Root directory for all persistent server state.
DATA_DIR = "server_data"

# JSON stores: registered doctors, submitted expense ciphertexts, uploaded
# report records, and key/parameter configuration.
DOCTORS_FILE, EXPENSES_FILE, REPORTS_FILE, CONF_FILE = (
    os.path.join(DATA_DIR, store_name)
    for store_name in ("doctors.json", "expenses.json", "reports.json", "config.json")
)

# PEM files holding the server's RSA key pair.
RSA_PRIV_FILE, RSA_PUB_FILE = (
    os.path.join(DATA_DIR, "server_rsa_priv.pem"),
    os.path.join(DATA_DIR, "server_rsa_pub.pem"),
)

# Serializes read-modify-write cycles on the JSON stores across the
# threaded request handlers.
lock = threading.Lock()
 | |
| 
 | |
def ensure_dirs():
    """Make sure DATA_DIR exists, creating it and any missing parents.

    Safe to call repeatedly: an already-existing directory is not an error.
    """
    os.makedirs(name=DATA_DIR, mode=0o777, exist_ok=True)
 | |
| 
 | |
def read_json(path, default):
    """Load and return the JSON document stored at *path*.

    Args:
        path: File to read.
        default: Value returned unchanged when *path* does not exist.

    Returns:
        The decoded JSON value, or *default* if the file is missing.

    Raises:
        json.JSONDecodeError: If the file exists but does not hold valid JSON.
    """
    # EAFP: an exists() check followed by open() leaves a window in which the
    # file can vanish; treating FileNotFoundError as "missing" closes it.
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return default
 | |
| 
 | |
def write_json(path, obj):
    """Atomically replace *path* with the JSON serialization of *obj*.

    The document is written to a sibling ``path + ".tmp"`` file, flushed to
    disk, and then moved over *path* with os.replace(), so readers see either
    the old or the new complete document.

    Raises:
        TypeError: If *obj* is not JSON-serializable.
        OSError: On I/O failure.
        In both cases the temporary file is removed and *path* is untouched.
    """
    tmp = path + ".tmp"
    try:
        with open(tmp, "w", encoding="utf-8") as f:
            json.dump(obj, f, indent=2)
            # Make the data durable before the rename makes it visible.
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp, path)
    except BaseException:
        # Do not leave a half-written temp file behind; the cleanup itself is
        # best-effort so it never masks the original error re-raised below.
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise
 | |
| 
 | |
def load_or_create_rsa():
    """Return the server RSA key pair, generating and persisting it on first run.

    On first run a 2048-bit key is generated. The private-key PEM is created
    with owner-only permissions (0o600, subject to the process umask and
    platform support), and the public-key PEM is written next to it. If the
    private key exists but the public PEM is missing, the public PEM is
    regenerated from the private key.

    Returns:
        (priv, pub): the RSA private key and the public key derived from it.
        The public key is derived from the private key, so the two always match.
    """
    if not os.path.exists(RSA_PRIV_FILE):
        key = RSA.generate(2048)
        # Restrict the private key file to the owner from the moment it exists.
        fd = os.open(RSA_PRIV_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "wb") as f:
            f.write(key.export_key())
        with open(RSA_PUB_FILE, "wb") as f:
            f.write(key.public_key().export_key())
    with open(RSA_PRIV_FILE, "rb") as f:
        priv = RSA.import_key(f.read())
    pub = priv.public_key()
    if not os.path.exists(RSA_PUB_FILE):
        with open(RSA_PUB_FILE, "wb") as f:
            f.write(pub.export_key())
    return priv, pub
 | |
| 
 | |
def load_or_create_paillier():
    """Return the server's Paillier (public, private) key pair.

    The key material is kept as decimal strings ``n``, ``p`` and ``q`` under
    the "paillier" entry of CONF_FILE. If that entry is absent, a fresh key
    pair is generated and saved first. The pair is then rebuilt from the
    values stored in the config file.
    """
    config = read_json(CONF_FILE, {})
    if "paillier" not in config:
        fresh_pub, fresh_priv = paillier.generate_paillier_keypair()
        config["paillier"] = {
            label: str(value)
            for label, value in (("n", fresh_pub.n), ("p", fresh_priv.p), ("q", fresh_priv.q))
        }
        write_json(CONF_FILE, config)

    stored = read_json(CONF_FILE, {})["paillier"]
    n, p, q = (int(stored[label]) for label in ("n", "p", "q"))
    public_key = paillier.PaillierPublicKey(n)
    return public_key, paillier.PaillierPrivateKey(public_key, p, q)
 | |
| 
 | |
def load_or_create_config_rsa_homomorphic_base(rsa_pub):
    """Return the public base g used by the RSA-in-exponent expense scheme.

    g is stored as a decimal string at conf["rsa_homomorphic"]["g"] in
    CONF_FILE. A new g is generated and saved when the entry is absent or
    unparseable, or when the stored value is not a valid base for the current
    modulus: outside 2 <= g < n - 1, or not coprime to n. The last case can
    happen if the RSA key files were regenerated while config.json was kept.

    Args:
        rsa_pub: RSA public key; only its modulus ``n`` is used.

    Returns:
        int g with 2 <= g <= n - 2 and gcd(g, n) == 1.
    """
    import secrets

    conf = read_json(CONF_FILE, {})
    n = rsa_pub.n

    entry = conf.get("rsa_homomorphic")
    stored = entry.get("g") if isinstance(entry, dict) else None
    try:
        g = int(stored)
    except (TypeError, ValueError):
        g = None

    if g is None or not (2 <= g < n - 1) or GCD(g, n) != 1:
        # Pick g uniformly in [2, n - 2] until it is coprime to n.
        while True:
            g = 2 + secrets.randbelow(n - 3)
            if GCD(g, n) == 1:
                break
        conf["rsa_homomorphic"] = {"g": str(g)}
        write_json(CONF_FILE, conf)
    return g
 | |
| 
 | |
def b64e(b: bytes) -> str:
    """Encode raw bytes as a standard-alphabet base64 text string."""
    encoded = base64.standard_b64encode(b)
    return encoded.decode("ascii")
 | |
| 
 | |
def b64d(s: str) -> bytes:
    """Decode a standard-alphabet base64 text string back into raw bytes.

    Decoding runs in base64's non-validating default mode, so characters
    outside the base64 alphabet are discarded before decoding.
    """
    raw = s.encode()
    return base64.standard_b64decode(raw)
 | |
| 
 | |
def init_storage():
    """Prepare on-disk server state and return the server RSA key pair.

    Steps:
    1. Create the data directory.
    2. Load or create the RSA and Paillier key material.
    3. Seed each JSON store that does not exist yet with an empty container:
       a dict for doctors, lists for expenses and reports.
    """
    ensure_dirs()
    rsa_priv, rsa_pub = load_or_create_rsa()
    load_or_create_paillier()  # run for its side effect of persisting the key pair

    empty_stores = (
        (DOCTORS_FILE, {}),
        (EXPENSES_FILE, []),
        (REPORTS_FILE, []),
    )
    for store_path, empty_value in empty_stores:
        if not os.path.exists(store_path):
            write_json(store_path, empty_value)
    return rsa_priv, rsa_pub
 | |
| 
 | |
# Importing this module has side effects: on first run it creates the data
# directory, key files and empty stores, then loads the key material into the
# module-level globals used by the request handlers and auditor utilities.
# Order matters: init_storage() creates DATA_DIR before CONF_FILE is touched.
RSA_PRIV, RSA_PUB = init_storage()
# NOTE(review): init_storage() already called load_or_create_paillier(); this
# second call rebuilds the same key pair from the entry now persisted in CONF_FILE.
PAI_PUB, PAI_PRIV = load_or_create_paillier()
# Public base g for the RSA-in-exponent expense scheme, coprime to RSA_PUB.n.
RSA_HOMO_G = load_or_create_config_rsa_homomorphic_base(RSA_PUB)
 | |
| 
 | |
def get_public_info():
    """Return the server's public parameters in a JSON-safe dict.

    Contents: the base64-encoded RSA public-key PEM, the RSA modulus and
    public exponent, the Paillier modulus, and the RSA-in-exponent base g.
    All integers are given as decimal strings.
    """
    info = {}
    info["rsa_pub_pem_b64"] = b64e(RSA_PUB.export_key())
    info["rsa_n"] = str(RSA_PUB.n)
    info["rsa_e"] = str(RSA_PUB.e)
    info["paillier_n"] = str(PAI_PUB.n)
    info["rsa_homo_g"] = str(RSA_HOMO_G)
    return info
 | |
| 
 | |
def handle_register_doctor(body):
    """Validate a doctor registration request and persist the doctor's record.

    Expected ``body`` keys:
        doctor_id        -- non-empty alphanumeric string.
        department_plain -- non-empty department name; it is also stored in
                            plaintext.
        dept_enc         -- {"ciphertext": int-like, "exponent": int-like}, the
                            Paillier-encrypted department value.
        elgamal_pub      -- {"p", "g", "y"}: int-like ElGamal public-key
                            components, later used to verify report signatures.

    All fields are validated and converted before the store lock is taken.
    Malformed input therefore produces a specific error response instead of an
    exception escaping to the request handler.

    Returns:
        {"status": "ok"} on success, else {"status": "error", "error": reason}.
    """
    def text_field(key):
        # Non-string values (numbers, null, objects) are treated as missing.
        value = body.get(key, "")
        return value.strip() if isinstance(value, str) else ""

    doc_id = text_field("doctor_id")
    dept_plain = text_field("department_plain")
    dept_enc = body.get("dept_enc")
    elgamal_pub = body.get("elgamal_pub")
    if not doc_id or not doc_id.isalnum():
        return {"status": "error", "error": "invalid doctor_id"}
    if not dept_plain:
        return {"status": "error", "error": "invalid department"}
    # The isinstance() guards matter: on a string, `"ciphertext" in value` is a
    # substring test and would accept e.g. "ciphertext exponent".
    if (not isinstance(dept_enc, dict)
            or "ciphertext" not in dept_enc
            or "exponent" not in dept_enc):
        return {"status": "error", "error": "invalid dept_enc"}
    if (not isinstance(elgamal_pub, dict)
            or not all(k in elgamal_pub for k in ("p", "g", "y"))):
        return {"status": "error", "error": "missing elgamal_pub"}

    # OverflowError covers JSON Infinity, which json.loads turns into a float.
    try:
        dept_record = {
            "ciphertext": str(int(dept_enc["ciphertext"])),
            "exponent": int(dept_enc["exponent"]),
        }
    except (TypeError, ValueError, OverflowError):
        return {"status": "error", "error": "invalid dept_enc"}
    try:
        elgamal_record = {k: str(int(elgamal_pub[k])) for k in ("p", "g", "y")}
    except (TypeError, ValueError, OverflowError):
        return {"status": "error", "error": "invalid elgamal_pub"}

    # NOTE(review): an existing doctor_id is silently overwritten, including its
    # ElGamal key. The "doctor" role is taken from the request as sent, so any
    # client can replace another doctor's key. Consider rejecting
    # re-registration or authenticating the caller.
    with lock:
        doctors = read_json(DOCTORS_FILE, {})
        doctors[doc_id] = {
            "department_plain": dept_plain,
            "dept_enc": dept_record,
            "elgamal_pub": elgamal_record,
        }
        write_json(DOCTORS_FILE, doctors)
    print(f"[server] registered doctor {doc_id} dept='{dept_plain}' (stored encrypted and plaintext)")
    return {"status": "ok"}
 | |
| 
 | |
def handle_upload_report(body):
    """Decrypt, integrity-check and store a report uploaded by a registered doctor.

    Expected ``body`` keys:
        doctor_id, filename, timestamp, md5_hex -- non-empty strings.
        sig -- {"r": int-like, "s": int-like}, the doctor's ElGamal signature.
        aes -- {"key_rsa_oaep_b64", "nonce_b64", "tag_b64", "ct_b64"}: an AES
               key encrypted to the server RSA key with OAEP, plus the EAX
               nonce, tag and ciphertext of the report, all base64.

    The upload is rejected, and nothing is written, in any of these cases:
    - a field is missing;
    - the signature components are not integers;
    - the doctor is unknown;
    - decryption or authentication fails;
    - the MD5 of the decrypted report does not match ``md5_hex``
      (compared case-insensitively).

    Otherwise the plaintext report is saved under DATA_DIR/reports and a
    record is appended to REPORTS_FILE.

    Returns:
        {"status": "ok"} on success, else {"status": "error", "error": reason}.
    """
    import hashlib

    def text_field(key):
        # Non-string values are treated as missing.
        value = body.get(key, "")
        return value.strip() if isinstance(value, str) else ""

    doc_id = text_field("doctor_id")
    # basename() drops directory components, so a client-supplied name cannot
    # escape the reports directory.
    filename = os.path.basename(text_field("filename"))
    timestamp = text_field("timestamp")
    md5_hex = text_field("md5_hex").lower()
    sig = body.get("sig")
    aes = body.get("aes")
    if not doc_id or not filename or not timestamp or not md5_hex or not sig or not aes:
        return {"status": "error", "error": "missing fields"}

    # Parse the signature before any side effect, so a malformed sig cannot
    # leave an orphaned report file on disk.
    try:
        sig_r = int(sig["r"])
        sig_s = int(sig["s"])
    except (KeyError, TypeError, ValueError, OverflowError):
        return {"status": "error", "error": "invalid sig"}

    with lock:
        doctors = read_json(DOCTORS_FILE, {})
    if doc_id not in doctors:
        return {"status": "error", "error": "unknown doctor_id"}

    # Unwrap the AES key with the server RSA private key, then decrypt and
    # authenticate the report (EAX raises if the tag does not match).
    try:
        rsa_cipher = PKCS1_OAEP.new(RSA_PRIV)
        aes_key = rsa_cipher.decrypt(b64d(aes["key_rsa_oaep_b64"]))
        nonce = b64d(aes["nonce_b64"])
        tag = b64d(aes["tag_b64"])
        ct = b64d(aes["ct_b64"])
        aes_cipher = AES.new(aes_key, AES.MODE_EAX, nonce=nonce)
        report_bytes = aes_cipher.decrypt_and_verify(ct, tag)
    except Exception as e:
        return {"status": "error", "error": f"aes/rsa decrypt failed: {e}"}

    # The client-declared MD5 must match what was actually received.
    md5_check = hashlib.md5(report_bytes).hexdigest()
    if md5_check != md5_hex:
        print("[server] md5 mismatch")
        return {"status": "error", "error": "md5 mismatch"}

    outdir = os.path.join(DATA_DIR, "reports")
    os.makedirs(outdir, exist_ok=True)
    # A nanosecond timestamp keeps two same-named uploads in the same second
    # from mapping to one path. "xb" refuses to overwrite if they still collide.
    savepath = os.path.join(outdir, f"{doc_id}_{time.time_ns()}_{filename}")
    with open(savepath, "xb") as f:
        f.write(report_bytes)

    rec = {
        "doctor_id": doc_id,
        "filename": filename,
        "saved_path": savepath,
        "timestamp": timestamp,
        "md5_hex": md5_hex,
        "sig": {"r": str(sig_r), "s": str(sig_s)},
    }
    with lock:
        records = read_json(REPORTS_FILE, [])
        records.append(rec)
        write_json(REPORTS_FILE, records)
    print(f"[server] report uploaded by {doc_id}, stored {savepath}")
    return {"status": "ok"}
 | |
| 
 | |
def handle_submit_expense(body):
    """Store an encrypted expense amount submitted by a registered doctor.

    Expected ``body``: {"doctor_id": str, "amount_ciphertext": int or decimal
    string}. The ciphertext is stored as a decimal string. The auditor later
    multiplies all ciphertexts modulo the RSA modulus n.

    A ciphertext is rejected in any of these cases:
    - it is a bool or float (floats may have lost precision, and int() would
      silently truncate them);
    - it does not parse as an integer;
    - it lies outside 0 < c < n. A 0 (or any value congruent to 0 mod n)
      would zero the auditor's product and wipe out every sum.

    Returns:
        {"status": "ok"} on success, else {"status": "error", "error": reason}.
    """
    raw_id = body.get("doctor_id", "")
    doc_id = raw_id.strip() if isinstance(raw_id, str) else ""
    c = body.get("amount_ciphertext")
    if not doc_id or not doc_id.isalnum():
        return {"status": "error", "error": "invalid doctor_id"}
    if isinstance(c, (bool, float)):
        return {"status": "error", "error": "invalid ciphertext"}
    try:
        c_int = int(c)
    except (TypeError, ValueError):
        return {"status": "error", "error": "invalid ciphertext"}
    if not 0 < c_int < RSA_PUB.n:
        return {"status": "error", "error": "invalid ciphertext"}

    with lock:
        if doc_id not in read_json(DOCTORS_FILE, {}):
            return {"status": "error", "error": "unknown doctor_id"}
        expenses = read_json(EXPENSES_FILE, [])
        expenses.append({"doctor_id": doc_id, "ciphertext": str(c_int)})
        write_json(EXPENSES_FILE, expenses)
    print(f"[server] expense ciphertext stored for {doc_id}")
    return {"status": "ok"}
 | |
| 
 | |
class RequestHandler(socketserver.StreamRequestHandler):
    """Serve one newline-terminated JSON request per TCP connection.

    A request has the shape {"action": ..., "role": ..., "body": {...}}. The
    reply is a single JSON line, either {"status": "ok", ...} or
    {"status": "error", "error": message}. Doctor-only actions are refused
    unless the request's "role" field equals "doctor".
    """

    def handle(self):
        """Read one request line, dispatch it, and write back one JSON reply.

        An empty read (client closed without sending) produces no reply. Any
        exception raised while parsing or handling is turned into an error reply.
        """
        # Tuple membership compares with ==, so an unhashable action (e.g. a
        # JSON list) is simply "not a doctor action" rather than a TypeError.
        doctor_actions = ("register_doctor", "upload_report", "submit_expense")
        try:
            line = self.rfile.readline()
            if not line:
                return
            request = json.loads(line.decode())
            action = request.get("action")
            role = request.get("role", "")
            body = request.get("body", {})

            if action == "get_public_info":
                resp = {"status": "ok", "data": get_public_info()}
            elif action not in doctor_actions:
                resp = {"status": "error", "error": "unknown action"}
            elif role != "doctor":
                resp = {"status": "error", "error": "unauthorized"}
            elif action == "register_doctor":
                resp = handle_register_doctor(body)
            elif action == "upload_report":
                resp = handle_upload_report(body)
            else:
                resp = handle_submit_expense(body)
        except Exception as exc:
            resp = {"status": "error", "error": str(exc)}
        self.wfile.write((json.dumps(resp) + "\n").encode())
 | |
| 
 | |
def start_server(host="127.0.0.1", port=None):
    """Start the threaded JSON request server on a background daemon thread.

    Args:
        host: Interface to bind. Defaults to loopback only.
        port: TCP port. Defaults to the module-level PORT.

    Returns:
        The running ThreadingTCPServer. To stop it, call ``shutdown()`` and
        then ``server_close()``.
    """
    if port is None:
        port = PORT
    server = socketserver.ThreadingTCPServer(
        (host, port), RequestHandler, bind_and_activate=False
    )
    # SO_REUSEADDR lets the server restart immediately while connections from
    # a previous run linger in TIME_WAIT. Without it, bind() fails with
    # "Address already in use".
    server.allow_reuse_address = True
    try:
        server.server_bind()
        server.server_activate()
    except BaseException:
        server.server_close()
        raise
    t = threading.Thread(target=server.serve_forever, daemon=True)
    t.start()
    print(f"[server] listening on {host}:{port}")
    return server
 | |
| 
 | |
| # Auditor utilities
 | |
| 
 | |
def load_doctors():
    """Return the doctor registry read from DOCTORS_FILE, or {} if the file is absent."""
    return read_json(path=DOCTORS_FILE, default={})
 | |
| 
 | |
def load_expenses():
    """Return the list of stored expense records, or [] if EXPENSES_FILE is absent."""
    return read_json(path=EXPENSES_FILE, default=[])
 | |
| 
 | |
def load_reports():
    """Return the list of stored report records, or [] if REPORTS_FILE is absent."""
    return read_json(path=REPORTS_FILE, default=[])
 | |
| 
 | |
def audit_list_doctors():
    """Print each registered doctor with the plaintext and encrypted department."""
    registry = load_doctors()
    print("Doctors:")
    for doctor_id, record in registry.items():
        encrypted = record["dept_enc"]
        plain = record["department_plain"]
        print(
            f"- {doctor_id} dept_plain='{plain}' "
            f"enc_ciphertext={encrypted['ciphertext']} exponent={encrypted['exponent']}"
        )
 | |
| 
 | |
def audit_keyword_search():
    """Interactively match registered doctors against a department keyword.

    The keyword is hashed with SHA-256 to an integer and encrypted under the
    server Paillier key. For each doctor, the encrypted query is subtracted
    from the stored ``dept_enc`` and the difference is decrypted. A zero
    difference is reported as a match.

    The comparison assumes clients encrypt the same SHA-256 integer of the
    department name (TODO confirm against the client).

    A doctor whose stored ciphertext cannot be evaluated is reported on its
    own line and skipped, instead of aborting the whole search. Causes include
    a malformed record, or a ciphertext made under a different key, whose
    decryption phe may reject as out of range.
    """
    import hashlib

    docs = load_doctors()
    if not docs:
        print("no doctors")
        return
    q = input("Enter department keyword to search: ").strip()
    if not q:
        print("empty")
        return

    h = int.from_bytes(hashlib.sha256(q.encode()).digest(), "big")
    pub = PAI_PUB
    priv = PAI_PRIV
    enc_q = pub.encrypt(h)
    print("Matching doctors (using Paillier equality on hashed dept):")
    for did, info in docs.items():
        try:
            enc = info["dept_enc"]
            c = int(enc["ciphertext"])
            exp = int(enc["exponent"])
            enc_doc = paillier.EncryptedNumber(pub, c, exp)
            dec = priv.decrypt(enc_doc - enc_q)
        except (KeyError, TypeError, ValueError, ArithmeticError) as err:
            # ArithmeticError covers OverflowError on an out-of-range decryption.
            print(f"  {did}: could not evaluate stored dept_enc ({err})")
            continue
        match = (dec == 0)
        print(f"  {did}: dept_plain='{info['department_plain']}' enc_ciphertext={c} match={match}")
 | |
| 
 | |
def rsa_homo_decrypt_sum(c_prod_int, max_iter=500000):
    """Recover the exponent sum from an RSA-in-exponent ciphertext product.

    The ciphertext is RSA-decrypted to m = g^sum mod n. The discrete log of m
    to base g is then searched for in [0, max_iter] with baby-step giant-step.
    That needs about 2*sqrt(max_iter) modular multiplications instead of the
    max_iter a linear scan would take.

    Args:
        c_prod_int: Product of expense ciphertexts mod n (int or int-like).
        max_iter: Largest exponent searched. The default is 500000.

    Returns:
        The smallest k in [0, max_iter] with g^k == m (mod n), or None if no
        such k exists.
    """
    n = RSA_PRIV.n
    d = RSA_PRIV.d
    g = RSA_HOMO_G
    # Decrypt to get g^sum mod n.
    m = pow(int(c_prod_int), d, n)
    if max_iter < 0:
        return None
    if m == 1:
        return 0

    # Choose step with step*step >= max_iter, so that giant steps i = 1..step
    # cover every k in [1, max_iter].
    step = int(max_iter ** 0.5) + 1
    while step * step < max_iter:
        step += 1

    # Baby steps: map m * g^j -> j for j in [0, step). Later j overwrite
    # earlier ones, so each value keeps its largest j. That yields the smallest
    # k = i*step - j below.
    baby = {}
    cur = m
    for j in range(step):
        baby[cur] = j
        cur = (cur * g) % n

    # Giant steps: g^(i*step) == m * g^j  <=>  g^(i*step - j) == m.
    # Step i covers k in [(i-1)*step + 1, i*step]. These windows are disjoint
    # and increasing, so the first hit is the smallest solution.
    giant = pow(g, step, n)
    cur = 1
    for i in range(1, step + 1):
        cur = (cur * giant) % n
        j = baby.get(cur)
        if j is not None:
            k = i * step - j
            return k if k <= max_iter else None
    return None
 | |
| 
 | |
def audit_sum_expenses():
    """Print the decrypted total of all stored expenses and a per-doctor breakdown.

    Ciphertexts are multiplied modulo the RSA modulus n; under the
    RSA-in-exponent scheme this adds the underlying amounts. Each product is
    decrypted with rsa_homo_decrypt_sum(), which yields None when the sum is
    not found within its search bound.

    Per-doctor products are accumulated in a single pass over the expenses, so
    the breakdown costs O(doctors + expenses) rather than rescanning every
    expense once per doctor.
    """
    exps = load_expenses()
    if not exps:
        print("no expenses")
        return

    n = RSA_PUB.n
    c_prod = 1
    for e in exps:
        c_prod = (c_prod * int(e["ciphertext"])) % n
    print(f"Product ciphertext (represents sum under RSA-in-exponent): {c_prod}")
    s = rsa_homo_decrypt_sum(c_prod)
    if s is None:
        print("sum decryption failed (exceeded search bound)")
    else:
        print(f"Decrypted sum of expenses = {s}")

    docs = load_doctors()
    if not docs:
        return
    print("Per-doctor sums:")
    # doctor_id -> (running product mod n, number of entries)
    per_doctor = {}
    for e in exps:
        did = e["doctor_id"]
        if did in docs:
            prod, count = per_doctor.get(did, (1, 0))
            per_doctor[did] = ((prod * int(e["ciphertext"])) % n, count + 1)
    # Report in the order doctors appear in DOCTORS_FILE, skipping doctors
    # with no expenses.
    for did in docs:
        if did not in per_doctor:
            continue
        c_prod_d, count = per_doctor[did]
        s_d = rsa_homo_decrypt_sum(c_prod_d)
        print(f"  {did}: entries={count} product_ct={c_prod_d} sum={s_d}")
 | |
| 
 | |
def elgamal_verify(p, g, y, H_int, r, s):
    """Verify an ElGamal signature (r, s) on the hash value H_int.

    The signature is accepted iff 0 < r < p, 0 < s < p - 1 and
    g^H ≡ y^r * r^s (mod p).

    The range checks are part of the scheme. Without 0 < r < p, a forger can
    pick an r >= p (combined via the CRT) that satisfies the congruence. A
    negative s would make pow() use a modular inverse, so out-of-range s
    values are rejected as well.

    Args:
        p, g, y: Signer's public key (prime modulus, generator, public value).
        H_int: Message hash as an integer, already reduced mod p - 1.
        r, s: Signature components.

    Returns:
        True if the signature is valid, else False.
    """
    if not (0 < r < p and 0 < s < p - 1):
        return False
    return pow(g, H_int, p) == (pow(y, r, p) * pow(r, s, p)) % p
 | |
| 
 | |
def _report_signature_ok(rec, docinfo):
    """Return True if the ElGamal signature stored in *rec* verifies against *docinfo*'s public key."""
    import hashlib

    if not docinfo:
        return False
    try:
        pub = docinfo["elgamal_pub"]
        p, g, y = int(pub["p"]), int(pub["g"]), int(pub["y"])
        r, s = int(rec["sig"]["r"]), int(rec["sig"]["s"])
        with open(rec["saved_path"], "rb") as f:
            report_bytes = f.read()
        # The signed value is MD5(report || timestamp), reduced mod p - 1.
        digest = hashlib.md5(report_bytes + rec["timestamp"].encode()).digest()
        H = int.from_bytes(digest, "big") % (p - 1)
        return elgamal_verify(p, g, y, H, r, s)
    except Exception:
        # Best effort per record: a missing file or a malformed key/signature
        # counts as a failed check instead of aborting the whole audit.
        return False


def _report_timestamp_ok(ts_text, max_future_skew=300):
    """Return True if *ts_text* parses and is at most *max_future_skew* seconds ahead of now.

    ISO-8601 is tried first, then "%Y-%m-%dT%H:%M:%S.%f". Timestamps without a
    UTC offset are interpreted as UTC. Timestamps with an offset are compared
    as the absolute instants they denote.
    """
    try:
        ts = datetime.fromisoformat(ts_text)
    except (TypeError, ValueError):
        try:
            ts = datetime.strptime(ts_text, "%Y-%m-%dT%H:%M:%S.%f")
        except (TypeError, ValueError):
            return False
    # Compare aware to aware. Mixing a naive "now" with an offset-carrying
    # timestamp raises TypeError.
    if ts.tzinfo is None:
        ts = ts.replace(tzinfo=timezone.utc)
    now = datetime.now(timezone.utc)
    return (now - ts).total_seconds() >= -max_future_skew


def audit_verify_reports():
    """Print, for every stored report, whether its signature and timestamp check out.

    Signature check: the saved report bytes are re-read and the ElGamal
    signature is verified against the uploading doctor's registered key.
    Timestamp check: the timestamp must parse and must not lie more than
    5 minutes in the future.
    """
    records = load_reports()
    if not records:
        print("no reports")
        return
    doctors = load_doctors()
    for rec in records:
        did = rec["doctor_id"]
        ok_sig = _report_signature_ok(rec, doctors.get(did))
        ok_ts = _report_timestamp_ok(rec["timestamp"])
        print(f"- report by {did} file={os.path.basename(rec['saved_path'])} sig_ok={ok_sig} ts_ok={ok_ts} ts={rec['timestamp']} md5={rec['md5_hex']}")
 | |
| 
 | |
def auditor_menu():
    """Run the interactive auditor console until the user selects 0."""
    menu_lines = (
        "\n[Auditor Menu]",
        "1) List doctors (show encrypted and plaintext dept)",
        "2) Keyword search doctors by dept (Paillier)",
        "3) Sum expenses (RSA-in-exponent demo)",
        "4) Verify reports and timestamps",
        "5) Show server public info",
        "0) Exit",
    )

    def show_public_info():
        print(json.dumps(get_public_info(), indent=2))

    commands = {
        "1": audit_list_doctors,
        "2": audit_keyword_search,
        "3": audit_sum_expenses,
        "4": audit_verify_reports,
        "5": show_public_info,
    }

    while True:
        for line in menu_lines:
            print(line)
        choice = input("Select: ").strip()
        if choice == "0":
            print("bye")
            break
        command = commands.get(choice)
        if command is None:
            print("invalid")
        else:
            command()
 | |
| 
 | |
if __name__ == "__main__":
    server = start_server()
    try:
        auditor_menu()
    finally:
        # Runs on normal exit and on Ctrl-C/EOF in the menu.
        # shutdown() stops the serve_forever loop. server_close() releases the
        # listening socket and waits for in-flight non-daemon handler threads.
        server.shutdown()
        server.server_close()
 | |
| 
 | 
