Browse Source

now clients can disconnect at any time

Simon 2 years ago
parent
commit
880617e7e5
2 changed files with 54 additions and 67 deletions
  1. 4 14
      follower/follower.go
  2. 50 53
      leader/leader.go

+ 4 - 14
follower/follower.go

@@ -218,10 +218,8 @@ func main() {
 
 func phase1(id int, leaderWorkerConnection net.Conn, m *sync.RWMutex, wg *sync.WaitGroup, virtualAddresses []int) {
 
-	gotClient := make([]byte, 1)
-
 	for {
-		gotClient = readFrom(leaderWorkerConnection, 1)
+		gotClient := readFrom(leaderWorkerConnection, 1)
 
 		//this worker is done
 		if gotClient[0] == 0 {
@@ -267,9 +265,7 @@ func phase1(id int, leaderWorkerConnection net.Conn, m *sync.RWMutex, wg *sync.W
 		}
 		dpfLength := byteToInt(dpfLengthBytes)
 
-		dpfQueryBEncrypted := make([]byte, dpfLength)
-
-		dpfQueryBEncrypted, errorBool = readFromWError(leaderWorkerConnection, dpfLength)
+		dpfQueryBEncrypted, errorBool := readFromWError(leaderWorkerConnection, dpfLength)
 		if errorBool {
 			continue
 		}
@@ -287,11 +283,9 @@ func phase1(id int, leaderWorkerConnection net.Conn, m *sync.RWMutex, wg *sync.W
 		pos := C.getUint128_t(C.int(virtualAddresses[dbWriteSize]))
 		C.evalDPF(C.ctx[id], (*C.uchar)(&dpfQueryB[0]), pos, C.int(ds), (*C.uchar)(&dataShareFollower[0]))
 
-		dataShareLeader := make([]byte, ds)
-
 		writeTo(leaderWorkerConnection, dataShareFollower)
 
-		dataShareLeader, errorBool = readFromWError(leaderWorkerConnection, ds)
+		dataShareLeader, errorBool := readFromWError(leaderWorkerConnection, ds)
 		if errorBool {
 			continue
 		}
@@ -570,9 +564,6 @@ func getSendTweets(clientKeys clientKeys, archiveQuerys [][]byte, leaderWorkerCo
 //returns true if client connection is lost
 func handlePirQuery(clientKeys clientKeys, leaderWorkerConnection net.Conn, subPhase int, clientPublicKey [32]byte, doAuditing bool) (clientKeys, [][]byte, bool) {
 
-	test := readFrom(leaderWorkerConnection, 1)
-	fmt.Println("test", test)
-
 	archiveNeededSubscriptions := make([]byte, 4)
 	if subPhase == -1 {
 		archiveNeededSubscriptions, errorBool := readFromWError(leaderWorkerConnection, 4)
@@ -703,8 +694,7 @@ func readFromWError(connection net.Conn, size int) ([]byte, bool) {
 	if array[0] == 1 {
 		return nil, true
 	}
-	fmt.Println(array)
-	return array, false
+	return array[1:], false
 }
 
 func transformBytesToStringArray(topicsAsBytes []byte) []string {

+ 50 - 53
leader/leader.go

@@ -94,16 +94,6 @@ func main() {
 	leaderPrivateKey = generatedPrivateKey
 	leaderPublicKey = generatedPublicKey
 
-	/*
-		if len(os.Args) != 4 {
-			fmt.Println("try again with: numThreads, dataLength, numRows")
-			return
-		}
-		numThreads, _ = strconv.Atoi(os.Args[2])
-		dataLength, _ = strconv.Atoi(os.Args[3])
-		numRows, _ = strconv.Atoi(os.Args[4])
-	*/
-
 	C.initializeServer(C.int(numThreads))
 
 	//calls follower for setup
@@ -125,7 +115,7 @@ func main() {
 	followerPublicKey = &tmpFollowerPubKey
 
 	//send publicKey to follower
-	writeTo(followerConnection, leaderPublicKey[:], nil, 0)
+	writeTo(followerConnection, leaderPublicKey[:])
 
 	//goroutine for accepting new clients
 	go func() {
@@ -255,7 +245,7 @@ func main() {
 		virtualAddresses := createVirtualAddresses()
 		//send all virtualAddresses to follower
 		for i := 0; i <= dbWriteSize; i++ {
-			writeTo(followerConnection, intToByte(virtualAddresses[i]), nil, 0)
+			writeTo(followerConnection, intToByte(virtualAddresses[i]))
 		}
 
 		for id := 0; id < numThreads; id++ {
@@ -326,7 +316,7 @@ func phase1(id int, phase []byte, followerConnection net.Conn, wg *sync.WaitGrou
 	for len(phase1Channel) == 0 {
 		if time.Since(startTime) > maxTimePerRound {
 			//tells follower that this worker is done
-			writeTo(followerConnection, gotClient, nil, 0)
+			writeTo(followerConnection, gotClient)
 			wg.Done()
 			return
 		}
@@ -337,13 +327,13 @@ func phase1(id int, phase []byte, followerConnection net.Conn, wg *sync.WaitGrou
 
 		gotClient[0] = 1
 		//tells follower that this worker got a clientConnection
-		writeTo(followerConnection, gotClient, nil, 0)
+		writeTo(followerConnection, gotClient)
 
 		//sends clients publicKey to follower
 		m.RLock()
 		clientPublicKey := clientData[clientConnection.RemoteAddr()].PublicKey
 		m.RUnlock()
-		writeTo(followerConnection, clientPublicKey[:], nil, 0)
+		writeTo(followerConnection, clientPublicKey[:])
 
 		//setup the worker-specific db
 		dbSize := int(C.dbSize)
@@ -353,7 +343,7 @@ func phase1(id int, phase []byte, followerConnection net.Conn, wg *sync.WaitGrou
 		}
 
 		//tells client that phase 1 has begun
-		errorBool := writeTo(clientConnection, phase, followerConnection, 5)
+		errorBool := writeToWError(clientConnection, phase, followerConnection, 5)
 		if errorBool {
 			contBool := handleClientDC(wg, followerConnection, phase1Channel)
 			if contBool {
@@ -364,7 +354,7 @@ func phase1(id int, phase []byte, followerConnection net.Conn, wg *sync.WaitGrou
 		}
 
 		//tells client current dbWriteSize
-		errorBool = writeTo(clientConnection, intToByte(dbWriteSize), followerConnection, 5)
+		errorBool = writeToWError(clientConnection, intToByte(dbWriteSize), followerConnection, 5)
 		if errorBool {
 			contBool := handleClientDC(wg, followerConnection, phase1Channel)
 			if contBool {
@@ -375,7 +365,7 @@ func phase1(id int, phase []byte, followerConnection net.Conn, wg *sync.WaitGrou
 		}
 
 		//tells client current round
-		errorBool = writeTo(clientConnection, roundAsBytes, followerConnection, 5)
+		errorBool = writeToWError(clientConnection, roundAsBytes, followerConnection, 5)
 		if errorBool {
 			contBool := handleClientDC(wg, followerConnection, phase1Channel)
 			if contBool {
@@ -444,10 +434,9 @@ func phase1(id int, phase []byte, followerConnection net.Conn, wg *sync.WaitGrou
 			}
 		}
 
-		//handledc
-		writeTo(followerConnection, dpfLengthBytes, nil, 0)
+		writeToWError(followerConnection, dpfLengthBytes, nil, 0)
 
-		writeTo(followerConnection, dpfQueryBEncrypted, nil, 0)
+		writeToWError(followerConnection, dpfQueryBEncrypted, nil, 0)
 
 		//decrypt dpfQueryA for sorting into db
 		var decryptNonce [24]byte
@@ -464,7 +453,7 @@ func phase1(id int, phase []byte, followerConnection net.Conn, wg *sync.WaitGrou
 
 		dataShareFollower, _ := readFrom(followerConnection, ds, nil, 0)
 
-		writeTo(followerConnection, dataShareLeader, nil, 0)
+		writeToWError(followerConnection, dataShareLeader, nil, 0)
 
 		auditXOR := make([]byte, ds)
 		passedAudit := true
@@ -514,7 +503,7 @@ func phase1(id int, phase []byte, followerConnection net.Conn, wg *sync.WaitGrou
 				//tells follower that this worker is done
 				gotClient[0] = 0
 
-				writeTo(followerConnection, gotClient, nil, 0)
+				writeTo(followerConnection, gotClient)
 
 				wg.Done()
 				return
@@ -540,12 +529,12 @@ func phase2(followerConnection net.Conn) {
 	}
 
 	//writes seed to follower
-	writeTo(followerConnection, seedLeader, nil, 0)
+	writeTo(followerConnection, seedLeader)
 
 	//write data to follower
 	//this is surely inefficent
 	for i := 0; i < dbSize; i++ {
-		writeTo(followerConnection, tmpdbLeader[i], nil, 0)
+		writeTo(followerConnection, tmpdbLeader[i])
 	}
 
 	//receive seed from follower
@@ -577,7 +566,7 @@ func phase2(followerConnection net.Conn) {
 
 	//send own Ciphers to follower
 	for i := 0; i < dbSize; i++ {
-		writeTo(followerConnection, C.GoBytes(unsafe.Pointer(ciphersLeader[i]), 16), nil, 0)
+		writeTo(followerConnection, C.GoBytes(unsafe.Pointer(ciphersLeader[i]), 16))
 	}
 
 	//receive ciphers from follower
@@ -666,7 +655,7 @@ func phase2(followerConnection net.Conn) {
 	dbWriteSize = int(math.Ceil(19.5 * float64(publisherAverage)))
 
 	//writes dbWriteSize of current round to follower
-	writeTo(followerConnection, intToByte(dbWriteSize), nil, 0)
+	writeTo(followerConnection, intToByte(dbWriteSize))
 }
 
 func addTestTweets() {
@@ -697,7 +686,7 @@ func phase3(id int, phase []byte, followerConnection net.Conn, wg *sync.WaitGrou
 	for len(phase3Channel) == 0 {
 		if time.Since(startTime) > maxTimePerRound {
 			//tells follower that this worker is done
-			writeTo(followerConnection, gotClient, nil, 0)
+			writeToWError(followerConnection, gotClient, nil, 0)
 			wg.Done()
 			return
 		}
@@ -708,10 +697,10 @@ func phase3(id int, phase []byte, followerConnection net.Conn, wg *sync.WaitGrou
 
 		gotClient[0] = 1
 		//tells follower that this worker got a clientConnection
-		writeTo(followerConnection, gotClient, nil, 0)
+		writeToWError(followerConnection, gotClient, nil, 0)
 
 		//tells client current phase
-		errorBool := writeTo(clientConnection, phase, followerConnection, 2)
+		errorBool := writeToWError(clientConnection, phase, followerConnection, 2)
 		if errorBool {
 			contBool := handleClientDC(wg, followerConnection, phase3Channel)
 			if contBool {
@@ -748,7 +737,7 @@ func phase3(id int, phase []byte, followerConnection net.Conn, wg *sync.WaitGrou
 		}
 
 		//tells client what leader expects
-		errorBool = writeTo(clientConnection, subPhase, followerConnection, 2)
+		errorBool = writeToWError(clientConnection, subPhase, followerConnection, 2)
 		if errorBool {
 			contBool := handleClientDC(wg, followerConnection, phase3Channel)
 			if contBool {
@@ -758,12 +747,11 @@ func phase3(id int, phase []byte, followerConnection net.Conn, wg *sync.WaitGrou
 			}
 		}
 
-		//handledc switch order
 		//tells follower what will happen
-		writeTo(followerConnection, subPhase, nil, 0)
+		writeToWError(followerConnection, subPhase, nil, 0)
 
 		//sends clients publicKey so follower knows which client is being served
-		writeTo(followerConnection, clientKeys.PublicKey[:], nil, 0)
+		writeTo(followerConnection, clientKeys.PublicKey[:])
 
 		//increases rounds participating for client
 		clientKeys.roundsParticipating = roundsParticipating + 1
@@ -837,7 +825,7 @@ func phase3(id int, phase []byte, followerConnection net.Conn, wg *sync.WaitGrou
 			}
 		}
 
-		writeTo(followerConnection, wantsArchive, nil, 0)
+		writeToWError(followerConnection, wantsArchive, nil, 0)
 
 		if wantsArchive[0] == 1 && archiveTopicAmount > 0 {
 			_, archiveQuerys, errorBool := handlePirQuery(clientKeys, clientConnection, followerConnection, -1, false)
@@ -880,7 +868,7 @@ func phase3(id int, phase []byte, followerConnection net.Conn, wg *sync.WaitGrou
 			} else {
 				//tells follower that this worker is done
 				gotClient[0] = 0
-				writeTo(followerConnection, gotClient, nil, 0)
+				writeToWError(followerConnection, gotClient, nil, 0)
 				wg.Done()
 				return
 			}
@@ -906,7 +894,7 @@ func handleClientDC(wg *sync.WaitGroup, followerConnection net.Conn, channel cha
 			gotClient := make([]byte, 1)
 			gotClient[0] = 0
 
-			writeTo(followerConnection, gotClient, nil, 0)
+			writeTo(followerConnection, gotClient)
 
 			wg.Done()
 			return false
@@ -960,7 +948,7 @@ func getSendVirtualAddress(pirQuery []byte, virtualAddresses []int, sharedSecret
 		virtualAddress[i] = virtualAddress[i] ^ virtualAddressFollower[i]
 	}
 
-	errorBool := writeTo(clientConnection, virtualAddress, followerConnection, 5)
+	errorBool := writeToWError(clientConnection, virtualAddress, followerConnection, 5)
 
 	return errorBool
 }
@@ -980,7 +968,7 @@ func getSendTweets(clientKeys clientKeys, archiveQuerys [][]byte, clientConnecti
 		}
 
 		//expand sharedSecret so it is of right length
-		expandBy := len(tweets) / 32
+		expandBy := len(tweets[i]) / 32
 		var expandedSharedSecret []byte
 		for i := 0; i < expandBy; i++ {
 			expandedSharedSecret = append(expandedSharedSecret, clientKeys.SharedSecret[:]...)
@@ -1002,11 +990,11 @@ func getSendTweets(clientKeys clientKeys, archiveQuerys [][]byte, clientConnecti
 	//sends tweets to client
 	for i := 0; i < tmpNeededSubscriptions; i++ {
 		tweetsLengthBytes := intToByte(len(tweets[i]))
-		errorBool := writeTo(clientConnection, tweetsLengthBytes, followerConnection, 2)
+		errorBool := writeToWError(clientConnection, tweetsLengthBytes, followerConnection, 2)
 		if errorBool {
 			return true
 		}
-		errorBool = writeTo(clientConnection, tweets[i], followerConnection, 2)
+		errorBool = writeToWError(clientConnection, tweets[i], followerConnection, 2)
 		if errorBool {
 			return true
 		}
@@ -1018,10 +1006,6 @@ func handlePirQuery(clientKeys clientKeys, clientConnection net.Conn, followerCo
 
 	clientPublicKey := clientKeys.PublicKey
 
-	test := make([]byte, 1)
-	test[0] = 1
-	followerConnection.Write(test)
-
 	//gets the msg length
 	msgLengthBytes, errorBool := readFrom(clientConnection, 4, followerConnection, 5)
 	if errorBool {
@@ -1049,7 +1033,7 @@ func handlePirQuery(clientKeys clientKeys, clientConnection net.Conn, followerCo
 			return clientKeys, nil, true
 		}
 
-		writeTo(followerConnection, archiveNeededSubscriptions, nil, 0)
+		writeToWError(followerConnection, archiveNeededSubscriptions, nil, 0)
 		tmpNeededSubscriptions = byteToInt(archiveNeededSubscriptions)
 		tmpTopicAmount = archiveTopicAmount
 	}
@@ -1059,11 +1043,10 @@ func handlePirQuery(clientKeys clientKeys, clientConnection net.Conn, followerCo
 	}
 
 	//send length to follower
-	fmt.Println(msgLengthBytes)
-	writeTo(followerConnection, msgLengthBytes, nil, 0)
+	writeToWError(followerConnection, msgLengthBytes, nil, 0)
 
 	//send box to follower
-	writeTo(followerConnection, followerBox, nil, 0)
+	writeToWError(followerConnection, followerBox, nil, 0)
 
 	var decryptNonce [24]byte
 	copy(decryptNonce[:], leaderBox[:24])
@@ -1139,11 +1122,11 @@ func sendTopicLists(clientConnection, followerConnection net.Conn, setup bool) b
 		topicListLengthBytes := intToByte(len(topicList))
 
 		if !setup {
-			err := writeTo(clientConnection, topicListLengthBytes, followerConnection, 5)
+			err := writeToWError(clientConnection, topicListLengthBytes, followerConnection, 5)
 			if err {
 				return true
 			}
-			err = writeTo(clientConnection, topicList, followerConnection, 5)
+			err = writeToWError(clientConnection, topicList, followerConnection, 5)
 			if err {
 				return true
 			}
@@ -1162,9 +1145,23 @@ func sendTopicLists(clientConnection, followerConnection net.Conn, setup bool) b
 }
 
 //sends the array to the connection
-//todo! need to split into WErro and WOError
-func writeTo(connection net.Conn, array []byte, followerConnection net.Conn, size int) bool {
+func writeTo(connection net.Conn, array []byte) {
 	_, err := connection.Write(array)
+	if err != nil {
+		panic(err)
+	}
+}
+
+func writeToWError(connection net.Conn, array []byte, followerConnection net.Conn, size int) bool {
+	var err error
+	if connection.RemoteAddr().String() == follower {
+		arrayWError := make([]byte, 1)
+		arrayWError = append(arrayWError, array[:]...)
+		_, err = connection.Write(arrayWError)
+	} else {
+		_, err = connection.Write(array)
+	}
+
 	if err != nil {
 		//lets follower know that client has disconnected unexpectedly
 		if connection.RemoteAddr().String() != follower {