From e6f6ab499a8cb1b8cd4c18d395a620ac7c85697a Mon Sep 17 00:00:00 2001 From: Koen <koen@futo.org> Date: Mon, 26 Aug 2024 13:29:57 +0000 Subject: [PATCH] Custom mDNS implementation for faster discovery. --- app/build.gradle | 9 +- .../api/media/models/chapters/IChapter.kt | 2 +- .../api/media/models/contents/ContentType.kt | 2 +- .../api/media/models/live/LiveEventType.kt | 2 +- .../api/media/models/post/TextType.kt | 2 +- .../api/media/models/ratings/RatingType.kt | 2 +- .../platformplayer/casting/StateCasting.kt | 178 ++---- .../dialogs/ConnectCastingDialog.kt | 3 + .../mainactivity/main/VideoDetailFragment.kt | 2 +- .../platformplayer/mdns/BroadcastService.kt | 11 + .../com/futo/platformplayer/mdns/DnsPacket.kt | 93 ++++ .../futo/platformplayer/mdns/DnsQuestion.kt | 110 ++++ .../com/futo/platformplayer/mdns/DnsReader.kt | 514 ++++++++++++++++++ .../platformplayer/mdns/DnsResourceRecord.kt | 117 ++++ .../com/futo/platformplayer/mdns/DnsWriter.kt | 208 +++++++ .../futo/platformplayer/mdns/Extensions.kt | 63 +++ .../futo/platformplayer/mdns/MDNSListener.kt | 482 ++++++++++++++++ .../futo/platformplayer/mdns/NICMonitor.kt | 66 +++ .../platformplayer/mdns/ServiceDiscoverer.kt | 68 +++ .../mdns/ServiceRecordAggregator.kt | 219 ++++++++ .../futo/platformplayer/views/FeedStyle.kt | 2 +- .../main/res/drawable/battery_full_24px.xml | 9 + .../java/com/futo/platformplayer/MdnsTests.kt | 394 ++++++++++++++ app/src/test/resources/samsung-airplay.hex | Bin 0 -> 642 bytes 24 files changed, 2423 insertions(+), 135 deletions(-) create mode 100644 app/src/main/java/com/futo/platformplayer/mdns/BroadcastService.kt create mode 100644 app/src/main/java/com/futo/platformplayer/mdns/DnsPacket.kt create mode 100644 app/src/main/java/com/futo/platformplayer/mdns/DnsQuestion.kt create mode 100644 app/src/main/java/com/futo/platformplayer/mdns/DnsReader.kt create mode 100644 app/src/main/java/com/futo/platformplayer/mdns/DnsResourceRecord.kt create mode 100644 app/src/main/java/com/futo/platformplayer/mdns/DnsWriter.kt create mode 100644 app/src/main/java/com/futo/platformplayer/mdns/Extensions.kt create mode 100644 app/src/main/java/com/futo/platformplayer/mdns/MDNSListener.kt create mode 100644 app/src/main/java/com/futo/platformplayer/mdns/NICMonitor.kt create mode 100644 app/src/main/java/com/futo/platformplayer/mdns/ServiceDiscoverer.kt create mode 100644 app/src/main/java/com/futo/platformplayer/mdns/ServiceRecordAggregator.kt create mode 100644 app/src/main/res/drawable/battery_full_24px.xml create mode 100644 app/src/test/java/com/futo/platformplayer/MdnsTests.kt create mode 100644 app/src/test/resources/samsung-airplay.hex diff --git a/app/build.gradle b/app/build.gradle index a3ddf53e..476fc35f 100644 --- a/app/build.gradle +++ b/app/build.gradle @@ -144,10 +144,18 @@ android { buildFeatures { buildConfig true } + sourceSets { + main { + assets { + srcDirs 'src/main/assets', 'src/tests/assets', 'src/test/assets' + } + } + } } dependencies { implementation 'com.google.dagger:dagger:2.48' + implementation 'androidx.test:monitor:1.7.2' annotationProcessor 'com.google.dagger:dagger-compiler:2.48' //Core @@ -186,7 +194,6 @@ dependencies { implementation 'androidx.media:media:1.7.0' //Other - implementation 'org.jmdns:jmdns:3.5.1' implementation 'org.jsoup:jsoup:1.15.3' implementation 'com.google.android.flexbox:flexbox:3.0.0' implementation 'androidx.swiperefreshlayout:swiperefreshlayout:1.1.0' diff --git a/app/src/main/java/com/futo/platformplayer/api/media/models/chapters/IChapter.kt b/app/src/main/java/com/futo/platformplayer/api/media/models/chapters/IChapter.kt index 5d7a7c7f..4ed5add4 100644 --- a/app/src/main/java/com/futo/platformplayer/api/media/models/chapters/IChapter.kt +++ b/app/src/main/java/com/futo/platformplayer/api/media/models/chapters/IChapter.kt @@ -23,7 +23,7 @@ enum class ChapterType(val value: Int) { companion object { fun fromInt(value: Int): ChapterType { - val result = ChapterType.values().firstOrNull { it.value == value }; + val result = ChapterType.entries.firstOrNull { it.value == value }; if(result == null) throw UnknownPlatformException(value.toString()); return result; diff --git a/app/src/main/java/com/futo/platformplayer/api/media/models/contents/ContentType.kt b/app/src/main/java/com/futo/platformplayer/api/media/models/contents/ContentType.kt index 27d51abe..a310e089 100644 --- a/app/src/main/java/com/futo/platformplayer/api/media/models/contents/ContentType.kt +++ b/app/src/main/java/com/futo/platformplayer/api/media/models/contents/ContentType.kt @@ -21,7 +21,7 @@ enum class ContentType(val value: Int) { companion object { fun fromInt(value: Int): ContentType { - val result = ContentType.values().firstOrNull { it.value == value }; + val result = ContentType.entries.firstOrNull { it.value == value }; if(result == null) throw UnknownPlatformException(value.toString()); return result; diff --git a/app/src/main/java/com/futo/platformplayer/api/media/models/live/LiveEventType.kt b/app/src/main/java/com/futo/platformplayer/api/media/models/live/LiveEventType.kt index ddec2e0a..c2f9bc4a 100644 --- a/app/src/main/java/com/futo/platformplayer/api/media/models/live/LiveEventType.kt +++ b/app/src/main/java/com/futo/platformplayer/api/media/models/live/LiveEventType.kt @@ -10,7 +10,7 @@ enum class LiveEventType(val value : Int) { companion object{ fun fromInt(value : Int) : LiveEventType{ - return LiveEventType.values().first { it.value == value }; + return LiveEventType.entries.first { it.value == value }; } } } \ No newline at end of file diff --git a/app/src/main/java/com/futo/platformplayer/api/media/models/post/TextType.kt b/app/src/main/java/com/futo/platformplayer/api/media/models/post/TextType.kt index 8a2c20d4..c1de57d1 100644 --- a/app/src/main/java/com/futo/platformplayer/api/media/models/post/TextType.kt +++ b/app/src/main/java/com/futo/platformplayer/api/media/models/post/TextType.kt @@ -10,7 +10,7 @@ enum class TextType(val value: Int) { companion object { fun fromInt(value: Int): TextType { - val result = TextType.values().firstOrNull { it.value == value }; + val result = TextType.entries.firstOrNull { it.value == value }; if(result == null) throw IllegalArgumentException("Unknown Texttype: $value"); return result; diff --git a/app/src/main/java/com/futo/platformplayer/api/media/models/ratings/RatingType.kt b/app/src/main/java/com/futo/platformplayer/api/media/models/ratings/RatingType.kt index eba21430..956b9d31 100644 --- a/app/src/main/java/com/futo/platformplayer/api/media/models/ratings/RatingType.kt +++ b/app/src/main/java/com/futo/platformplayer/api/media/models/ratings/RatingType.kt @@ -8,7 +8,7 @@ enum class RatingType(val value : Int) { companion object{ fun fromInt(value : Int) : RatingType{ - return RatingType.values().first { it.value == value }; + return RatingType.entries.first { it.value == value }; } } } \ No newline at end of file diff --git a/app/src/main/java/com/futo/platformplayer/casting/StateCasting.kt b/app/src/main/java/com/futo/platformplayer/casting/StateCasting.kt index 5e99462e..c0301d40 100644 --- a/app/src/main/java/com/futo/platformplayer/casting/StateCasting.kt +++ b/app/src/main/java/com/futo/platformplayer/casting/StateCasting.kt @@ -31,6 +31,8 @@ import com.futo.platformplayer.constructs.Event1 import com.futo.platformplayer.constructs.Event2 import com.futo.platformplayer.exceptions.UnsupportedCastException import com.futo.platformplayer.logging.Logger +import com.futo.platformplayer.mdns.DnsService +import com.futo.platformplayer.mdns.ServiceDiscoverer import com.futo.platformplayer.models.CastingDeviceInfo import com.futo.platformplayer.parsers.HLS import com.futo.platformplayer.states.StateApp @@ -45,15 +47,10 @@ import kotlinx.serialization.Serializable import kotlinx.serialization.json.Json import java.net.InetAddress import java.util.UUID -import javax.jmdns.JmDNS -import javax.jmdns.ServiceEvent -import javax.jmdns.ServiceListener -import javax.jmdns.ServiceTypeListener class StateCasting { private val _scopeIO = CoroutineScope(Dispatchers.IO); private val _scopeMain = CoroutineScope(Dispatchers.Main); - private var _jmDNS: JmDNS? = null; private val _storage: CastingDeviceInfoStorage = FragmentedStorage.get(); private val _castServer = ManagedHttpServer(9999); @@ -72,102 +69,46 @@ class StateCasting { var activeDevice: CastingDevice? = null; private val _client = ManagedHttpClient(); var _resumeCastingDevice: CastingDeviceInfo? = null; + val _serviceDiscoverer = ServiceDiscoverer(arrayOf( + "_googlecast._tcp.local", + "_airplay._tcp.local", + "_fastcast._tcp.local", + "_fcast._tcp.local" + )) { handleServiceUpdated(it) } val isCasting: Boolean get() = activeDevice != null; - private val _chromecastServiceListener = object : ServiceListener { - override fun serviceAdded(event: ServiceEvent) { - Logger.i(TAG, "ChromeCast service added: " + event.info); - addOrUpdateDevice(event); - } - - override fun serviceRemoved(event: ServiceEvent) { - Logger.i(TAG, "ChromeCast service removed: " + event.info); - synchronized(devices) { - val device = devices[event.info.name]; - if (device != null) { - onDeviceRemoved.emit(device); + private fun handleServiceUpdated(services: List<DnsService>) { + for (s in services) { + //TODO: Addresses IPv4 only? + val addresses = s.addresses.toTypedArray() + val port = s.port.toInt() + var name = s.texts.firstOrNull { it.startsWith("md=") }?.substring("md=".length) + if (s.name.endsWith("._googlecast._tcp.local")) { + if (name == null) { + name = s.name.substring(0, s.name.length - "._googlecast._tcp.local".length) } - } - } - override fun serviceResolved(event: ServiceEvent) { - Logger.v(TAG, "ChromeCast service resolved: " + event.info); - addOrUpdateDevice(event); - } - - fun addOrUpdateDevice(event: ServiceEvent) { - addOrUpdateChromeCastDevice(event.info.name, event.info.inetAddresses, event.info.port); - } - } - - private val _airPlayServiceListener = object : ServiceListener { - override fun serviceAdded(event: ServiceEvent) { - Logger.i(TAG, "AirPlay service added: " + event.info); - addOrUpdateDevice(event); - } - - override fun serviceRemoved(event: ServiceEvent) { - Logger.i(TAG, "AirPlay service removed: " + event.info); - synchronized(devices) { - val device = devices[event.info.name]; - if (device != null) { - onDeviceRemoved.emit(device); + addOrUpdateChromeCastDevice(name, addresses, port) + } else if (s.name.endsWith("._airplay._tcp.local")) { + if (name == null) { + name = s.name.substring(0, s.name.length - "._airplay._tcp.local".length) } - } - } - - override fun serviceResolved(event: ServiceEvent) { - Logger.i(TAG, "AirPlay service resolved: " + event.info); - addOrUpdateDevice(event); - } - fun addOrUpdateDevice(event: ServiceEvent) { - addOrUpdateAirPlayDevice(event.info.name, event.info.inetAddresses, event.info.port); - } - } - - private val _fastCastServiceListener = object : ServiceListener { - override fun serviceAdded(event: ServiceEvent) { - Logger.i(TAG, "FastCast service added: " + event.info); - addOrUpdateDevice(event); - } - - override fun serviceRemoved(event: ServiceEvent) { - Logger.i(TAG, "FastCast service removed: " + event.info); - synchronized(devices) { - val device = devices[event.info.name]; - if (device != null) { - onDeviceRemoved.emit(device); + addOrUpdateAirPlayDevice(name, addresses, port) + } else if (s.name.endsWith("._fastcast._tcp.local")) { + if (name == null) { + name = s.name.substring(0, s.name.length - "._fastcast._tcp.local".length) } - } - } - - override fun serviceResolved(event: ServiceEvent) { - Logger.i(TAG, "FastCast service resolved: " + event.info); - addOrUpdateDevice(event); - } - - fun addOrUpdateDevice(event: ServiceEvent) { - addOrUpdateFastCastDevice(event.info.name, event.info.inetAddresses, event.info.port); - } - } - - private val _serviceTypeListener = object : ServiceTypeListener { - override fun serviceTypeAdded(event: ServiceEvent?) { - if (event == null) { - return; - } - Logger.i(TAG, "Service type added (name: ${event.name}, type: ${event.type})"); - } + addOrUpdateFastCastDevice(name, addresses, port) + } else if (s.name.endsWith("._fcast._tcp.local")) { + if (name == null) { + name = s.name.substring(0, s.name.length - "._fcast._tcp.local".length) + } - override fun subTypeForServiceTypeAdded(event: ServiceEvent?) { - if (event == null) { - return; + addOrUpdateFastCastDevice(name, addresses, port) } - - Logger.i(TAG, "Sub type for service type added (name: ${event.name}, type: ${event.type})"); } } @@ -237,29 +178,30 @@ class StateCasting { rememberedDevices.clear(); rememberedDevices.addAll(_storage.deviceInfos.map { deviceFromCastingDeviceInfo(it) }); - _scopeIO.launch { - try { - val jmDNS = JmDNS.create(InetAddress.getLocalHost()); - jmDNS.addServiceListener("_googlecast._tcp.local.", _chromecastServiceListener); - jmDNS.addServiceListener("_airplay._tcp.local.", _airPlayServiceListener); - jmDNS.addServiceListener("_fastcast._tcp.local.", _fastCastServiceListener); - jmDNS.addServiceListener("_fcast._tcp.local.", _fastCastServiceListener); - - if (BuildConfig.DEBUG) { - jmDNS.addServiceTypeListener(_serviceTypeListener); - } - - _jmDNS = jmDNS; - } catch (e: Throwable) { - Logger.e(TAG, "Failed to start casting service.", e); - } - } _castServer.start(); enableDeveloper(true); Logger.i(TAG, "CastingService started."); } + @Synchronized + fun startDiscovering() { + try { + _serviceDiscoverer.start() + } catch (e: Throwable) { + Logger.i(TAG, "Failed to start ServiceDiscoverer", e) + } + } + + @Synchronized + fun stopDiscovering() { + try { + _serviceDiscoverer.stop() + } catch (e: Throwable) { + Logger.i(TAG, "Failed to stop ServiceDiscoverer", e) + } + } + @Synchronized fun stop() { if (!_started) @@ -269,25 +211,7 @@ class StateCasting { Logger.i(TAG, "CastingService stopping.") - val jmDNS = _jmDNS; - if (jmDNS != null) { - _scopeIO.launch { - try { - jmDNS.removeServiceListener("_googlecast._tcp.local.", _chromecastServiceListener); - jmDNS.removeServiceListener("_airplay._tcp", _airPlayServiceListener); - jmDNS.removeServiceListener("_fastcast._tcp.local.", _fastCastServiceListener); - - if (BuildConfig.DEBUG) { - jmDNS.removeServiceTypeListener(_serviceTypeListener); - } - - jmDNS.close(); - } catch (e: Throwable) { - Logger.e(TAG, "Failed to stop mDNS.", e); - } - } - } - + stopDiscovering() _scopeIO.cancel(); _scopeMain.cancel(); @@ -1245,7 +1169,7 @@ class StateCasting { } } else { val newDevice = deviceFactory(); - devices[name] = newDevice; + this.devices[name] = newDevice; invokeEvents = { onDeviceAdded.emit(newDevice); diff --git a/app/src/main/java/com/futo/platformplayer/dialogs/ConnectCastingDialog.kt b/app/src/main/java/com/futo/platformplayer/dialogs/ConnectCastingDialog.kt index e9e826fe..bd5da2ea 100644 --- a/app/src/main/java/com/futo/platformplayer/dialogs/ConnectCastingDialog.kt +++ b/app/src/main/java/com/futo/platformplayer/dialogs/ConnectCastingDialog.kt @@ -104,6 +104,8 @@ class ConnectCastingDialog(context: Context?) : AlertDialog(context) { super.show(); Logger.i(TAG, "Dialog shown."); + StateCasting.instance.startDiscovering() + (_imageLoader.drawable as Animatable?)?.start(); _devices.clear(); @@ -169,6 +171,7 @@ class ConnectCastingDialog(context: Context?) : AlertDialog(context) { (_imageLoader.drawable as Animatable?)?.stop(); + StateCasting.instance.stopDiscovering() StateCasting.instance.onDeviceAdded.remove(this); StateCasting.instance.onDeviceChanged.remove(this); StateCasting.instance.onDeviceRemoved.remove(this); diff --git a/app/src/main/java/com/futo/platformplayer/fragment/mainactivity/main/VideoDetailFragment.kt b/app/src/main/java/com/futo/platformplayer/fragment/mainactivity/main/VideoDetailFragment.kt index e1f311fb..cbefca3b 100644 --- a/app/src/main/java/com/futo/platformplayer/fragment/mainactivity/main/VideoDetailFragment.kt +++ b/app/src/main/java/com/futo/platformplayer/fragment/mainactivity/main/VideoDetailFragment.kt @@ -157,7 +157,7 @@ class VideoDetailFragment : MainFragment { _viewDetail?.preventPictureInPicture = true; } - fun minimizeVideoDetail(){ + fun minimizeVideoDetail() { _viewDetail?.setFullscreen(false); if(_view != null) _view!!.transitionToStart(); diff --git a/app/src/main/java/com/futo/platformplayer/mdns/BroadcastService.kt b/app/src/main/java/com/futo/platformplayer/mdns/BroadcastService.kt new file mode 100644 index 00000000..ac3c61e0 --- /dev/null +++ b/app/src/main/java/com/futo/platformplayer/mdns/BroadcastService.kt @@ -0,0 +1,11 @@ +package com.futo.platformplayer.mdns + +data class BroadcastService( + val deviceName: String, + val serviceName: String, + val port: UShort, + val ttl: UInt, + val weight: UShort, + val priority: UShort, + val texts: List<String>? = null +) \ No newline at end of file diff --git a/app/src/main/java/com/futo/platformplayer/mdns/DnsPacket.kt b/app/src/main/java/com/futo/platformplayer/mdns/DnsPacket.kt new file mode 100644 index 00000000..2c27edf8 --- /dev/null +++ b/app/src/main/java/com/futo/platformplayer/mdns/DnsPacket.kt @@ -0,0 +1,93 @@ +package com.futo.platformplayer.mdns + +import java.nio.ByteBuffer +import java.nio.ByteOrder + +enum class QueryResponse(val value: Byte) { + Query(0), + Response(1) +} + +enum class DnsOpcode(val value: Byte) { + StandardQuery(0), + InverseQuery(1), + ServerStatusRequest(2) +} + +enum class DnsResponseCode(val value: Byte) { + NoError(0), + FormatError(1), + ServerFailure(2), + NameError(3), + NotImplemented(4), + Refused(5) +} + +data class DnsPacketHeader( + val identifier: UShort, + val queryResponse: Int, + val opcode: Int, + val authoritativeAnswer: Boolean, + val truncated: Boolean, + val recursionDesired: Boolean, + val recursionAvailable: Boolean, + val answerAuthenticated: Boolean, + val nonAuthenticatedData: Boolean, + val responseCode: DnsResponseCode +) + +data class DnsPacket( + val header: DnsPacketHeader, + val questions: List<DnsQuestion>, + val answers: List<DnsResourceRecord>, + val authorities: List<DnsResourceRecord>, + val additionals: List<DnsResourceRecord> +) { + companion object { + fun parse(data: ByteArray): DnsPacket { + val span = data.asUByteArray() + val flags = (span[2].toInt() shl 8 or span[3].toInt()).toUShort() + val questionCount = (span[4].toInt() shl 8 or span[5].toInt()).toUShort() + val answerCount = (span[6].toInt() shl 8 or span[7].toInt()).toUShort() + val authorityCount = (span[8].toInt() shl 8 or span[9].toInt()).toUShort() + val additionalCount = (span[10].toInt() shl 8 or span[11].toInt()).toUShort() + + var position = 12 + + val questions = List(questionCount.toInt()) { + DnsQuestion.parse(data, position).also { position = it.second } + }.map { it.first } + + val answers = List(answerCount.toInt()) { + DnsResourceRecord.parse(data, position).also { position = it.second } + }.map { it.first } + + val authorities = List(authorityCount.toInt()) { + DnsResourceRecord.parse(data, position).also { position = it.second } + }.map { it.first } + + val additionals = List(additionalCount.toInt()) { + DnsResourceRecord.parse(data, position).also { position = it.second } + }.map { it.first } + + return DnsPacket( + header = DnsPacketHeader( + identifier = (span[0].toInt() shl 8 or span[1].toInt()).toUShort(), + queryResponse = ((flags.toUInt() shr 15) and 0b1u).toInt(), + opcode = ((flags.toUInt() shr 11) and 0b1111u).toInt(), + authoritativeAnswer = (flags.toInt() shr 10) and 0b1 != 0, + truncated = (flags.toInt() shr 9) and 0b1 != 0, + recursionDesired = (flags.toInt() shr 8) and 0b1 != 0, + recursionAvailable = (flags.toInt() shr 7) and 0b1 != 0, + answerAuthenticated = (flags.toInt() shr 5) and 0b1 != 0, + nonAuthenticatedData = (flags.toInt() shr 4) and 0b1 != 0, + responseCode = DnsResponseCode.entries[flags.toInt() and 0b1111] + ), + questions = questions, + answers = answers, + authorities = authorities, + additionals = additionals + ) + } + } +} \ No newline at end of file diff --git a/app/src/main/java/com/futo/platformplayer/mdns/DnsQuestion.kt b/app/src/main/java/com/futo/platformplayer/mdns/DnsQuestion.kt new file mode 100644 index 00000000..01a7bd77 --- /dev/null +++ b/app/src/main/java/com/futo/platformplayer/mdns/DnsQuestion.kt @@ -0,0 +1,110 @@ +package com.futo.platformplayer.mdns + +import com.futo.platformplayer.mdns.Extensions.readDomainName +import java.nio.ByteBuffer +import java.nio.ByteOrder + + +enum class QuestionType(val value: UShort) { + A(1u), + NS(2u), + MD(3u), + MF(4u), + CNAME(5u), + SOA(6u), + MB(7u), + MG(8u), + MR(9u), + NULL(10u), + WKS(11u), + PTR(12u), + HINFO(13u), + MINFO(14u), + MX(15u), + TXT(16u), + RP(17u), + AFSDB(18u), + SIG(24u), + KEY(25u), + AAAA(28u), + LOC(29u), + SRV(33u), + NAPTR(35u), + KX(36u), + CERT(37u), + DNAME(39u), + APL(42u), + DS(43u), + SSHFP(44u), + IPSECKEY(45u), + RRSIG(46u), + NSEC(47u), + DNSKEY(48u), + DHCID(49u), + NSEC3(50u), + NSEC3PARAM(51u), + TSLA(52u), + SMIMEA(53u), + HIP(55u), + CDS(59u), + CDNSKEY(60u), + OPENPGPKEY(61u), + CSYNC(62u), + ZONEMD(63u), + SVCB(64u), + HTTPS(65u), + EUI48(108u), + EUI64(109u), + TKEY(249u), + TSIG(250u), + URI(256u), + CAA(257u), + TA(32768u), + DLV(32769u), + AXFR(252u), + IXFR(251u), + OPT(41u), + MAILB(253u), + MALA(254u), + All(252u) +} + +enum class QuestionClass(val value: UShort) { + IN(1u), + CS(2u), + CH(3u), + HS(4u), + All(255u) +} + +data class DnsQuestion( + override val name: String, + override val type: Int, + override val clazz: Int, + val queryUnicast: Boolean +) : DnsResourceRecordBase(name, type, clazz) { + companion object { + fun parse(data: ByteArray, startPosition: Int): Pair<DnsQuestion, Int> { + val span = data.asUByteArray() + var position = startPosition + val qname = span.readDomainName(position).also { position = it.second } + val qtype = (span[position].toInt() shl 8 or span[position + 1].toInt()).toUShort() + position += 2 + val qclass = (span[position].toInt() shl 8 or span[position + 1].toInt()).toUShort() + position += 2 + + return DnsQuestion( + name = qname.first, + type = qtype.toInt(), + queryUnicast = ((qclass.toInt() shr 15) and 0b1) != 0, + clazz = qclass.toInt() and 0b111111111111111 + ) to position + } + } +} + +open class DnsResourceRecordBase( + open val name: String, + open val type: Int, + open val clazz: Int +) diff --git a/app/src/main/java/com/futo/platformplayer/mdns/DnsReader.kt b/app/src/main/java/com/futo/platformplayer/mdns/DnsReader.kt new file mode 100644 index 00000000..83c329ff --- /dev/null +++ b/app/src/main/java/com/futo/platformplayer/mdns/DnsReader.kt @@ -0,0 +1,514 @@ +package com.futo.platformplayer.mdns + +import com.futo.platformplayer.mdns.Extensions.readDomainName +import java.nio.ByteBuffer +import java.nio.ByteOrder +import java.nio.charset.StandardCharsets +import kotlin.math.pow +import java.net.InetAddress + +data class PTRRecord(val domainName: String) + +data class ARecord(val address: InetAddress) + +data class AAAARecord(val address: InetAddress) + +data class MXRecord(val preference: UShort, val exchange: String) + +data class CNAMERecord(val cname: String) + +data class TXTRecord(val texts: List<String>) + +data class SOARecord( + val primaryNameServer: String, + val responsibleAuthorityMailbox: String, + val serialNumber: Int, + val refreshInterval: Int, + val retryInterval: Int, + val expiryLimit: Int, + val minimumTTL: Int +) + +data class SRVRecord(val priority: UShort, val weight: UShort, val port: UShort, val target: String) + +data class NSRecord(val nameServer: String) + +data class CAARecord(val flags: Byte, val tag: String, val value: String) + +data class HINFORecord(val cpu: String, val os: String) + +data class RPRecord(val mailbox: String, val txtDomainName: String) + + +data class AFSDBRecord(val subtype: UShort, val hostname: String) +data class LOCRecord( + val version: Byte, + val size: Double, + val horizontalPrecision: Double, + val verticalPrecision: Double, + val latitude: Double, + val longitude: Double, + val altitude: Double +) { + companion object { + fun decodeSizeOrPrecision(coded: Byte): Double { + val baseValue = (coded.toInt() shr 4) and 0x0F + val exponent = coded.toInt() and 0x0F + return baseValue * 10.0.pow(exponent.toDouble()) + } + + fun decodeLatitudeOrLongitude(coded: Int): Double { + val arcSeconds = coded / 1E3 + return arcSeconds / 3600.0 + } + + fun decodeAltitude(coded: Int): Double { + return (coded / 100.0) - 100000.0 + } + } +} + +data class NAPTRRecord( + val order: UShort, + val preference: UShort, + val flags: String, + val services: String, + val regexp: String, + val replacement: String +) + +data class RRSIGRecord( + val typeCovered: UShort, + val algorithm: Byte, + val labels: Byte, + val originalTTL: UInt, + val signatureExpiration: UInt, + val signatureInception: UInt, + val keyTag: UShort, + val signersName: String, + val signature: ByteArray +) + +data class KXRecord(val preference: UShort, val exchanger: String) + +data class CERTRecord(val type: UShort, val keyTag: UShort, val algorithm: Byte, val certificate: ByteArray) + + + +data class DNAMERecord(val target: String) + +data class DSRecord(val keyTag: UShort, val algorithm: Byte, val digestType: Byte, val digest: ByteArray) + +data class SSHFPRecord(val algorithm: Byte, val fingerprintType: Byte, val fingerprint: ByteArray) + +data class TLSARecord(val usage: Byte, val selector: Byte, val matchingType: Byte, val certificateAssociationData: ByteArray) + +data class SMIMEARecord(val usage: Byte, val selector: Byte, val matchingType: Byte, val certificateAssociationData: ByteArray) + +data class URIRecord(val priority: UShort, val weight: UShort, val target: String) + +data class NSECRecord(val ownerName: String, val typeBitMaps: List<Pair<Byte, ByteArray>>) +data class NSEC3Record( + val hashAlgorithm: Byte, + val flags: Byte, + val iterations: UShort, + val salt: ByteArray, + val nextHashedOwnerName: ByteArray, + val typeBitMaps: List<UShort> +) + +data class NSEC3PARAMRecord(val hashAlgorithm: Byte, val flags: Byte, val iterations: UShort, val salt: ByteArray) +data class SPFRecord(val texts: List<String>) +data class TKEYRecord( + val algorithm: String, + val inception: UInt, + val expiration: UInt, + val mode: UShort, + val error: UShort, + val keyData: ByteArray, + val otherData: ByteArray +) + +data class TSIGRecord( + val algorithmName: String, + val timeSigned: UInt, + val fudge: UShort, + val mac: ByteArray, + val originalID: UShort, + val error: UShort, + val otherData: ByteArray +) + +data class OPTRecordOption(val code: UShort, val data: ByteArray) +data class OPTRecord(val options: List<OPTRecordOption>) + +class DnsReader(private val data: ByteArray, private var position: Int = 0, private val length: Int = data.size) { + + private val endPosition: Int = position + length + + fun readDomainName(): String { + return data.asUByteArray().readDomainName(position).also { position = it.second }.first + } + + fun readDouble(): Double { + checkRemainingBytes(Double.SIZE_BYTES) + val result = ByteBuffer.wrap(data, position, Double.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).double + position += Double.SIZE_BYTES + return result + } + + fun readInt16(): Short { + checkRemainingBytes(Short.SIZE_BYTES) + val result = ByteBuffer.wrap(data, position, Short.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).short + position += Short.SIZE_BYTES + return result + } + + fun readInt32(): Int { + checkRemainingBytes(Int.SIZE_BYTES) + val result = ByteBuffer.wrap(data, position, Int.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).int + position += Int.SIZE_BYTES + return result + } + + fun readInt64(): Long { + checkRemainingBytes(Long.SIZE_BYTES) + val result = ByteBuffer.wrap(data, position, Long.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).long + position += Long.SIZE_BYTES + return result + } + + fun readSingle(): Float { + checkRemainingBytes(Float.SIZE_BYTES) + val result = ByteBuffer.wrap(data, position, Float.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).float + position += Float.SIZE_BYTES + return result + } + + fun readByte(): Byte { + checkRemainingBytes(Byte.SIZE_BYTES) + return data[position++] + } + + fun readBytes(length: Int): ByteArray { + checkRemainingBytes(length) + return ByteArray(length).also { data.copyInto(it, startIndex = position, endIndex = position + length) } + .also { position += length } + } + + fun readUInt16(): UShort { + checkRemainingBytes(Short.SIZE_BYTES) + val result = ByteBuffer.wrap(data, position, Short.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).short.toUShort() + position += Short.SIZE_BYTES + return result + } + + fun readUInt32(): UInt { + checkRemainingBytes(Int.SIZE_BYTES) + val result = ByteBuffer.wrap(data, position, Int.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).int.toUInt() + position += Int.SIZE_BYTES + return result + } + + fun readUInt64(): ULong { + checkRemainingBytes(Long.SIZE_BYTES) + val result = ByteBuffer.wrap(data, position, Long.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).long.toULong() + position += Long.SIZE_BYTES + return result + } + + fun readString(): String { + val length = data[position++].toInt() + checkRemainingBytes(length) + return String(data, position, length, StandardCharsets.UTF_8).also { position += length } + } + + private fun checkRemainingBytes(requiredBytes: Int) { + if (position + requiredBytes > endPosition) throw IndexOutOfBoundsException() + } + + fun readRPRecord(): RPRecord { + return RPRecord(readDomainName(), readDomainName()) + } + + fun readKXRecord(): KXRecord { + val preference = readUInt16() + val exchanger = readDomainName() + return KXRecord(preference, exchanger) + } + + fun readCERTRecord(): CERTRecord { + val type = readUInt16() + val keyTag = readUInt16() + val algorithm = readByte() + val certificateLength = readUInt16().toInt() - 5 + val certificate = readBytes(certificateLength) + return CERTRecord(type, keyTag, algorithm, certificate) + } + + fun readPTRRecord(): PTRRecord { + return PTRRecord(readDomainName()) + } + + fun readARecord(): ARecord { + val address = readBytes(4) + return ARecord(InetAddress.getByAddress(address)) + } + + fun readAAAARecord(): AAAARecord { + val address = readBytes(16) + return AAAARecord(InetAddress.getByAddress(address)) + } + + fun readMXRecord(): MXRecord { + val preference = readUInt16() + val exchange = readDomainName() + return MXRecord(preference, exchange) + } + + fun readCNAMERecord(): CNAMERecord { + return CNAMERecord(readDomainName()) + } + + fun readTXTRecord(): TXTRecord { + val texts = mutableListOf<String>() + while (position < endPosition) { + val textLength = data[position++].toInt() + checkRemainingBytes(textLength) + val text = String(data, position, textLength, StandardCharsets.UTF_8) + texts.add(text) + position += textLength + } + return TXTRecord(texts) + } + + fun readSOARecord(): SOARecord { + val primaryNameServer = readDomainName() + val responsibleAuthorityMailbox = readDomainName() + val serialNumber = readInt32() + val refreshInterval = readInt32() + val retryInterval = readInt32() + val expiryLimit = readInt32() + val minimumTTL = readInt32() + return SOARecord(primaryNameServer, responsibleAuthorityMailbox, serialNumber, refreshInterval, retryInterval, expiryLimit, minimumTTL) + } + + fun readSRVRecord(): SRVRecord { + val priority = readUInt16() + val weight = readUInt16() + val port = readUInt16() + val target = readDomainName() + return SRVRecord(priority, weight, port, target) + } + + fun readNSRecord(): NSRecord { + return NSRecord(readDomainName()) + } + + fun readCAARecord(): CAARecord { + val length = readUInt16().toInt() + val flags = readByte() + val tagLength = readByte().toInt() + val tag = String(data, position, tagLength, StandardCharsets.US_ASCII).also { position += tagLength } + val valueLength = length - 1 - 1 - tagLength + val value = String(data, position, valueLength, StandardCharsets.US_ASCII).also { position += valueLength } + return CAARecord(flags, tag, value) + } + + fun readHINFORecord(): HINFORecord { + val cpuLength = readByte().toInt() + val cpu = String(data, position, cpuLength, StandardCharsets.US_ASCII).also { position += cpuLength } + val osLength = readByte().toInt() + val os = String(data, position, osLength, StandardCharsets.US_ASCII).also { position += osLength } + return HINFORecord(cpu, os) + } + + fun readAFSDBRecord(): AFSDBRecord { + return AFSDBRecord(readUInt16(), readDomainName()) + } + + fun readLOCRecord(): LOCRecord { + val version = readByte() + val size = LOCRecord.decodeSizeOrPrecision(readByte()) + val horizontalPrecision = LOCRecord.decodeSizeOrPrecision(readByte()) + val verticalPrecision = LOCRecord.decodeSizeOrPrecision(readByte()) + val latitudeCoded = readInt32() + val longitudeCoded = readInt32() + val altitudeCoded = readInt32() + val latitude = LOCRecord.decodeLatitudeOrLongitude(latitudeCoded) + val longitude = LOCRecord.decodeLatitudeOrLongitude(longitudeCoded) + val altitude = LOCRecord.decodeAltitude(altitudeCoded) + return LOCRecord(version, size, horizontalPrecision, verticalPrecision, latitude, longitude, altitude) + } + + fun readNAPTRRecord(): NAPTRRecord { + val order = readUInt16() + val preference = readUInt16() + val flags = readString() + val services = readString() + val regexp = readString() + val replacement = readDomainName() + return NAPTRRecord(order, preference, flags, services, regexp, replacement) + } + + fun readDNAMERecord(): DNAMERecord { + return DNAMERecord(readDomainName()) + } + + fun readDSRecord(): DSRecord { + val keyTag = readUInt16() + val algorithm = readByte() + val digestType = readByte() + val digestLength = readUInt16().toInt() - 4 + val digest = readBytes(digestLength) + return DSRecord(keyTag, algorithm, digestType, digest) + } + + fun readSSHFPRecord(): SSHFPRecord { + val algorithm = readByte() + val fingerprintType = readByte() + val fingerprintLength = readUInt16().toInt() - 2 + val fingerprint = readBytes(fingerprintLength) + return SSHFPRecord(algorithm, fingerprintType, fingerprint) + } + + fun readTLSARecord(): TLSARecord { + val usage = readByte() + val selector = readByte() + val matchingType = readByte() + val dataLength = readUInt16().toInt() - 3 + val certificateAssociationData = readBytes(dataLength) + return TLSARecord(usage, selector, matchingType, certificateAssociationData) + } + + fun readSMIMEARecord(): SMIMEARecord { + val usage = readByte() + val selector = readByte() + val matchingType = readByte() + val dataLength = readUInt16().toInt() - 3 + val certificateAssociationData = readBytes(dataLength) + return SMIMEARecord(usage, selector, matchingType, certificateAssociationData) + } + + fun readURIRecord(): URIRecord { + val priority = readUInt16() + val weight = readUInt16() + val length = readUInt16().toInt() + val target = String(data, position, length, StandardCharsets.US_ASCII).also { position += length } + return URIRecord(priority, weight, target) + } + + fun readRRSIGRecord(): RRSIGRecord { + val typeCovered = readUInt16() + val algorithm = readByte() + val labels = readByte() + val originalTTL = readUInt32() + val signatureExpiration = readUInt32() + val signatureInception = readUInt32() + val keyTag = readUInt16() + val signersName = readDomainName() + val signatureLength = readUInt16().toInt() + val signature = readBytes(signatureLength) + return RRSIGRecord( + typeCovered, + algorithm, + labels, + originalTTL, + signatureExpiration, + signatureInception, + keyTag, + signersName, + signature + ) + } + + fun readNSECRecord(): NSECRecord { + val ownerName = readDomainName() + val typeBitMaps = mutableListOf<Pair<Byte, ByteArray>>() + while (position < endPosition) { + val windowBlock = readByte() + val bitmapLength = readByte().toInt() + val bitmap = readBytes(bitmapLength) + typeBitMaps.add(windowBlock to bitmap) + } + return NSECRecord(ownerName, typeBitMaps) + } + + fun readNSEC3Record(): NSEC3Record { + val hashAlgorithm = readByte() + val flags = readByte() + val iterations = readUInt16() + val saltLength = readByte().toInt() + val salt = readBytes(saltLength) + val hashLength = readByte().toInt() + val nextHashedOwnerName = readBytes(hashLength) + val bitMapLength = readUInt16().toInt() + val typeBitMaps = mutableListOf<UShort>() + val endPos = position + bitMapLength + while (position < endPos) { + typeBitMaps.add(readUInt16()) + } + return NSEC3Record(hashAlgorithm, flags, iterations, salt, nextHashedOwnerName, typeBitMaps) + } + + fun readNSEC3PARAMRecord(): NSEC3PARAMRecord { + val hashAlgorithm = readByte() + val flags = readByte() + val iterations = readUInt16() + val saltLength = readByte().toInt() + val salt = readBytes(saltLength) + return NSEC3PARAMRecord(hashAlgorithm, flags, iterations, salt) + } + + + fun readSPFRecord(): SPFRecord { + val length = readUInt16().toInt() + val texts = mutableListOf<String>() + val endPos = position + length + while (position < endPos) { + val textLength = readByte().toInt() + val text = String(data, position, textLength, StandardCharsets.US_ASCII).also { position += textLength } + texts.add(text) + } + return SPFRecord(texts) + } + + fun readTKEYRecord(): TKEYRecord { + val algorithm = readDomainName() + val inception = readUInt32() + val expiration = readUInt32() + val mode = readUInt16() + val error = readUInt16() + val keySize = readUInt16().toInt() + val keyData = readBytes(keySize) + val otherSize = readUInt16().toInt() + val otherData = readBytes(otherSize) + return TKEYRecord(algorithm, inception, expiration, mode, error, keyData, otherData) + } + + fun readTSIGRecord(): TSIGRecord { + val algorithmName = readDomainName() + val timeSigned = readUInt32() + val fudge = readUInt16() + val macSize = readUInt16().toInt() + val mac = readBytes(macSize) + val originalID = readUInt16() + val error = readUInt16() + val otherSize = readUInt16().toInt() + val otherData = readBytes(otherSize) + return TSIGRecord(algorithmName, timeSigned, fudge, mac, originalID, error, otherData) + } + + + + fun readOPTRecord(): OPTRecord { + val options = mutableListOf<OPTRecordOption>() + while (position < endPosition) { + val optionCode = readUInt16() + val optionLength = readUInt16().toInt() + val optionData = readBytes(optionLength) + options.add(OPTRecordOption(optionCode, optionData)) + } + return OPTRecord(options) + } +} diff --git a/app/src/main/java/com/futo/platformplayer/mdns/DnsResourceRecord.kt b/app/src/main/java/com/futo/platformplayer/mdns/DnsResourceRecord.kt new file mode 100644 index 00000000..87ec0e5f --- /dev/null +++ b/app/src/main/java/com/futo/platformplayer/mdns/DnsResourceRecord.kt @@ -0,0 +1,117 @@ +package com.futo.platformplayer.mdns + +import com.futo.platformplayer.mdns.Extensions.readDomainName + +enum class ResourceRecordType(val value: UShort) { + None(0u), + A(1u), + NS(2u), + MD(3u), + MF(4u), + CNAME(5u), + SOA(6u), + MB(7u), + MG(8u), + MR(9u), + NULL(10u), + WKS(11u), + PTR(12u), + HINFO(13u), + MINFO(14u), + MX(15u), + TXT(16u), + RP(17u), + AFSDB(18u), + SIG(24u), + KEY(25u), + AAAA(28u), + LOC(29u), + SRV(33u), + NAPTR(35u), + KX(36u), + CERT(37u), + DNAME(39u), + APL(42u), + DS(43u), + SSHFP(44u), + IPSECKEY(45u), + RRSIG(46u), + NSEC(47u), + DNSKEY(48u), + DHCID(49u), + NSEC3(50u), + NSEC3PARAM(51u), + TSLA(52u), + SMIMEA(53u), + HIP(55u), + CDS(59u), + CDNSKEY(60u), + OPENPGPKEY(61u), + CSYNC(62u), + ZONEMD(63u), + SVCB(64u), + HTTPS(65u), + EUI48(108u), + EUI64(109u), + TKEY(249u), + TSIG(250u), + URI(256u), + CAA(257u), + TA(32768u), + DLV(32769u), + AXFR(252u), + IXFR(251u), + OPT(41u) +} + +enum class ResourceRecordClass(val value: UShort) { + IN(1u), + CS(2u), + CH(3u), + HS(4u) +} + +data class DnsResourceRecord( + override val name: String, + override val type: Int, + override val clazz: Int, + val timeToLive: UInt, + val cacheFlush: Boolean, + val dataPosition: Int = -1, + val dataLength: Int = -1, + private val data: ByteArray? = null +) : DnsResourceRecordBase(name, type, clazz) { + + companion object { + fun parse(data: ByteArray, startPosition: Int): Pair<DnsResourceRecord, Int> { + val span = data.asUByteArray() + var position = startPosition + val name = span.readDomainName(position).also { position = it.second } + val type = (span[position].toInt() shl 8 or span[position + 1].toInt()).toUShort() + position += 2 + val clazz = (span[position].toInt() shl 8 or span[position + 1].toInt()).toUShort() + position += 2 + val ttl = (span[position].toInt() shl 24 or (span[position + 1].toInt() shl 16) or + (span[position + 2].toInt() shl 8) or span[position + 3].toInt()).toUInt() + position += 4 + val rdlength = (span[position].toInt() shl 8 or span[position + 1].toInt()).toUShort() + val rdposition = position + 2 + position += 2 + rdlength.toInt() + + return DnsResourceRecord( + name = name.first, + type = type.toInt(), + clazz = clazz.toInt() and 0b1111111_11111111, + timeToLive = ttl, + cacheFlush = ((clazz.toInt() shr 15) and 0b1) != 0, + dataPosition = rdposition, + dataLength = rdlength.toInt(), + data = data + ) to position + } + } + + fun getDataReader(): DnsReader { + return DnsReader(data!!, dataPosition, dataLength) + } +} diff --git a/app/src/main/java/com/futo/platformplayer/mdns/DnsWriter.kt b/app/src/main/java/com/futo/platformplayer/mdns/DnsWriter.kt new file mode 100644 index 00000000..5b2c1f5c --- /dev/null +++ b/app/src/main/java/com/futo/platformplayer/mdns/DnsWriter.kt @@ -0,0 +1,208 @@ +package com.futo.platformplayer.mdns + +import java.nio.ByteBuffer +import java.nio.ByteOrder +import java.nio.charset.StandardCharsets + +class DnsWriter { + private val data = mutableListOf<Byte>() + private val namePositions = mutableMapOf<String, Int>() + + fun toByteArray(): ByteArray = data.toByteArray() + + fun writePacket( + header: DnsPacketHeader, + questionCount: Int? = null, questionWriter: ((DnsWriter, Int) -> Unit)? = null, + answerCount: Int? = null, answerWriter: ((DnsWriter, Int) -> Unit)? = null, + authorityCount: Int? = null, authorityWriter: ((DnsWriter, Int) -> Unit)? = null, + additionalsCount: Int? = null, additionalWriter: ((DnsWriter, Int) -> Unit)? = null + ) { + if (questionCount != null && questionWriter == null || questionCount == null && questionWriter != null) + throw Exception("When question count is given, question writer should also be given.") + if (answerCount != null && answerWriter == null || answerCount == null && answerWriter != null) + throw Exception("When answer count is given, answer writer should also be given.") + if (authorityCount != null && authorityWriter == null || authorityCount == null && authorityWriter != null) + throw Exception("When authority count is given, authority writer should also be given.") + if (additionalsCount != null && additionalWriter == null || additionalsCount == null && additionalWriter != null) + throw Exception("When additionals count is given, additional writer should also be given.") + + writeHeader(header, questionCount ?: 0, answerCount ?: 0, authorityCount ?: 0, additionalsCount ?: 0) + + repeat(questionCount ?: 0) { questionWriter?.invoke(this, it) } + repeat(answerCount ?: 0) { answerWriter?.invoke(this, it) } + repeat(authorityCount ?: 0) { authorityWriter?.invoke(this, it) } + repeat(additionalsCount ?: 0) { additionalWriter?.invoke(this, it) } + } + + fun writeHeader(header: DnsPacketHeader, questionCount: Int, answerCount: Int, authorityCount: Int, additionalsCount: Int) { + write(header.identifier) + + var flags: UShort = 0u + flags = flags or ((header.queryResponse.toUInt() and 0xFFFFu) shl 15).toUShort() + flags = flags or ((header.opcode.toUInt() and 0xFFFFu) shl 11).toUShort() + flags = flags or ((if (header.authoritativeAnswer) 1u else 0u) shl 10).toUShort() + flags = flags or ((if (header.truncated) 1u else 0u) shl 9).toUShort() + flags = flags or ((if (header.recursionDesired) 1u else 0u) shl 8).toUShort() + flags = flags or ((if (header.recursionAvailable) 1u else 0u) shl 7).toUShort() + flags = flags or ((if (header.answerAuthenticated) 1u else 0u) shl 5).toUShort() + flags = flags or ((if (header.nonAuthenticatedData) 1u else 0u) shl 4).toUShort() + flags = flags or header.responseCode.value.toUShort() + write(flags) + + write(questionCount.toUShort()) + write(answerCount.toUShort()) + write(authorityCount.toUShort()) + write(additionalsCount.toUShort()) + } + + fun writeDomainName(name: String) { + synchronized(namePositions) { + val labels = name.split('.') + for (label in labels) { + val nameAtOffset = name.substring(name.indexOf(label)) + if (namePositions.containsKey(nameAtOffset)) { + val position = namePositions[nameAtOffset]!! + val pointer = (0b11000000_00000000 or position).toUShort() + write(pointer) + return + } + if (label.isNotEmpty()) { + val labelBytes = label.toByteArray(StandardCharsets.UTF_8) + val nameStartPos = data.size + write(labelBytes.size.toByte()) + write(labelBytes) + namePositions[nameAtOffset] = nameStartPos + } + } + write(0.toByte()) // End of domain name + } + } + + fun write(value: DnsResourceRecord, dataWriter: (DnsWriter) -> Unit) { + writeDomainName(value.name) + write(value.type.toUShort()) + val cls = ((if (value.cacheFlush) 1u else 0u) shl 15).toUShort() or value.clazz.toUShort() + write(cls) + write(value.timeToLive) + + val lengthOffset = data.size + write(0.toUShort()) + dataWriter(this) + val rdLength = data.size - lengthOffset - 2 + val rdLengthBytes = ByteBuffer.allocate(2).order(ByteOrder.BIG_ENDIAN).putShort(rdLength.toShort()).array() + data[lengthOffset] = rdLengthBytes[0] + data[lengthOffset + 1] = rdLengthBytes[1] + } + + fun write(value: DnsQuestion) { + writeDomainName(value.name) + write(value.type.toUShort()) + write(((if (value.queryUnicast) 1u else 0u shl 15).toUShort() or value.clazz.toUShort())) + } + + fun write(value: Double) { + val bytes = ByteBuffer.allocate(Double.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).putDouble(value).array() + write(bytes) + } + + fun write(value: Short) { + val bytes = ByteBuffer.allocate(Short.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).putShort(value).array() + write(bytes) + } + + fun write(value: Int) { + val bytes = ByteBuffer.allocate(Int.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).putInt(value).array() + write(bytes) + } + + fun write(value: Long) { + val bytes = ByteBuffer.allocate(Long.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).putLong(value).array() + write(bytes) + } + + fun write(value: Float) { + val bytes = ByteBuffer.allocate(Float.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).putFloat(value).array() + write(bytes) + } + + fun write(value: Byte) { + data.add(value) + } + + fun write(value: ByteArray) { + data.addAll(value.asIterable()) + } + + fun write(value: ByteArray, offset: Int, length: Int) { + data.addAll(value.slice(offset until offset + length)) + } + + fun write(value: UShort) { + val bytes = ByteBuffer.allocate(Short.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).putShort(value.toShort()).array() + write(bytes) + } + + fun write(value: UInt) { + val bytes = ByteBuffer.allocate(Int.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).putInt(value.toInt()).array() + write(bytes) + } + + fun write(value: ULong) { + val bytes = ByteBuffer.allocate(Long.SIZE_BYTES).order(ByteOrder.BIG_ENDIAN).putLong(value.toLong()).array() + write(bytes) + } + + fun write(value: String) { + val bytes = value.toByteArray(StandardCharsets.UTF_8) + write(bytes.size.toByte()) + write(bytes) + } + + fun write(value: PTRRecord) { + writeDomainName(value.domainName) + } + + fun write(value: ARecord) { + val bytes = value.address.address + if (bytes.size != 4) throw Exception("Unexpected amount of address bytes.") + write(bytes) + } + + fun write(value: AAAARecord) { + val bytes = value.address.address + if (bytes.size != 16) throw Exception("Unexpected amount of address bytes.") + write(bytes) + } + + fun write(value: TXTRecord) { + value.texts.forEach { + val bytes = it.toByteArray(StandardCharsets.UTF_8) + write(bytes.size.toByte()) + write(bytes) + } + } + + fun write(value: SRVRecord) { + write(value.priority) + write(value.weight) + write(value.port) + writeDomainName(value.target) + } + + fun write(value: NSECRecord) { + writeDomainName(value.ownerName) + value.typeBitMaps.forEach { (windowBlock, bitmap) -> + write(windowBlock) + write(bitmap.size.toByte()) + write(bitmap) + } + } + + fun write(value: OPTRecord) { + value.options.forEach { option -> + write(option.code) + write(option.data.size.toUShort()) + write(option.data) + } + } +} \ No newline at end of file diff --git a/app/src/main/java/com/futo/platformplayer/mdns/Extensions.kt b/app/src/main/java/com/futo/platformplayer/mdns/Extensions.kt new file mode 100644 index 00000000..48bb4c6a --- /dev/null +++ b/app/src/main/java/com/futo/platformplayer/mdns/Extensions.kt @@ -0,0 +1,63 @@ +package com.futo.platformplayer.mdns + +import android.util.Log + +object Extensions { + fun ByteArray.toByteDump(): String { + val result = StringBuilder() + for (i in indices) { + result.append(String.format("%02X ", this[i])) + + if ((i + 1) % 16 == 0 || i == size - 1) { + val padding = 3 * (16 - (i % 16 + 1)) + if (i == size - 1 && (i + 1) % 16 != 0) result.append(" ".repeat(padding)) + + result.append("; ") + val start = i - (i % 16) + val end = minOf(i, size - 1) + for (j in start..end) { + val ch = if (this[j] in 32..127) this[j].toChar() else '.' + result.append(ch) + } + if (i != size - 1) result.appendLine() + } + } + return result.toString() + } + + fun UByteArray.readDomainName(startPosition: Int): Pair<String, Int> { + var position = startPosition + return readDomainName(position, 0) + } + + private fun UByteArray.readDomainName(position: Int, depth: Int = 0): Pair<String, Int> { + if (depth > 16) throw Exception("Exceeded maximum recursion depth in DNS packet. Possible circular reference.") + + val domainParts = mutableListOf<String>() + var newPosition = position + + while (true) { + if (newPosition < 0) + println() + + val length = this[newPosition].toUByte() + if ((length and 0b11000000u).toUInt() == 0b11000000u) { + val offset = (((length and 0b00111111u).toUInt()) shl 8) or this[newPosition + 1].toUInt() + val (part, _) = this.readDomainName(offset.toInt(), depth + 1) + domainParts.add(part) + newPosition += 2 + break + } else if (length.toUInt() == 0u) { + newPosition++ + break + } else { + newPosition++ + val part = String(this.asByteArray(), newPosition, length.toInt(), Charsets.UTF_8) + domainParts.add(part) + newPosition += length.toInt() + } + } + + return domainParts.joinToString(".") to newPosition + } +} \ No newline at end of file diff --git a/app/src/main/java/com/futo/platformplayer/mdns/MDNSListener.kt b/app/src/main/java/com/futo/platformplayer/mdns/MDNSListener.kt new file mode 100644 index 00000000..494c4934 --- /dev/null +++ b/app/src/main/java/com/futo/platformplayer/mdns/MDNSListener.kt @@ -0,0 +1,482 @@ +package com.futo.platformplayer.mdns + +import com.futo.platformplayer.logging.Logger +import kotlinx.coroutines.* +import java.net.* +import java.util.* +import java.util.concurrent.locks.ReentrantLock +import kotlin.concurrent.withLock + +class MDNSListener { + companion object { + private val TAG = "MDNSListener" + const val MulticastPort = 5353 + val MulticastAddressIPv4: InetAddress = InetAddress.getByName("224.0.0.251") + val MulticastAddressIPv6: InetAddress = InetAddress.getByName("FF02::FB") + val MdnsEndpointIPv6: InetSocketAddress = InetSocketAddress(MulticastAddressIPv6, MulticastPort) + val MdnsEndpointIPv4: InetSocketAddress = InetSocketAddress(MulticastAddressIPv4, MulticastPort) + } + + private val _lockObject = ReentrantLock() + private var _receiver4: DatagramSocket? = null + private var _receiver6: DatagramSocket? = null + private val _senders = mutableListOf<DatagramSocket>() + private val _nicMonitor = NICMonitor() + private val _serviceRecordAggregator = ServiceRecordAggregator() + private var _started = false + private var _threadReceiver4: Thread? = null + private var _threadReceiver6: Thread? = null + private var _scope: CoroutineScope? = null + + var onPacket: ((DnsPacket) -> Unit)? = null + var onServicesUpdated: ((List<DnsService>) -> Unit)? = null + + private val _recordLockObject = ReentrantLock() + private val _recordsA = mutableListOf<Pair<DnsResourceRecord, ARecord>>() + private val _recordsAAAA = mutableListOf<Pair<DnsResourceRecord, AAAARecord>>() + private val _recordsPTR = mutableListOf<Pair<DnsResourceRecord, PTRRecord>>() + private val _recordsTXT = mutableListOf<Pair<DnsResourceRecord, TXTRecord>>() + private val _recordsSRV = mutableListOf<Pair<DnsResourceRecord, SRVRecord>>() + private val _services = mutableListOf<BroadcastService>() + + init { + _nicMonitor.added = { onNicsAdded(it) } + _nicMonitor.removed = { onNicsRemoved(it) } + _serviceRecordAggregator.onServicesUpdated = { onServicesUpdated?.invoke(it) } + } + + fun start() { + if (_started) throw Exception("Already running.") + _started = true + + _scope = CoroutineScope(Dispatchers.IO); + + Logger.i(TAG, "Starting") + _lockObject.withLock { + val receiver4 = DatagramSocket(null).apply { + reuseAddress = true + bind(InetSocketAddress(InetAddress.getByName("0.0.0.0"), MulticastPort)) + } + _receiver4 = receiver4 + + val receiver6 = DatagramSocket(null).apply { + reuseAddress = true + bind(InetSocketAddress(InetAddress.getByName("::"), MulticastPort)) + } + _receiver6 = receiver6 + + _nicMonitor.start() + _serviceRecordAggregator.start() + onNicsAdded(_nicMonitor.current) + + _threadReceiver4 = Thread { + receiveLoop(receiver4) + }.apply { start() } + + _threadReceiver6 = Thread { + receiveLoop(receiver6) + }.apply { start() } + } + } + + fun queryServices(names: Array<String>) { + if (names.isEmpty()) throw IllegalArgumentException("At least one name must be specified.") + + val writer = DnsWriter() + writer.writePacket( + DnsPacketHeader( + identifier = 0u, + queryResponse = QueryResponse.Query.value.toInt(), + opcode = DnsOpcode.StandardQuery.value.toInt(), + truncated = false, + nonAuthenticatedData = false, + recursionDesired = false, + answerAuthenticated = false, + authoritativeAnswer = false, + recursionAvailable = false, + responseCode = DnsResponseCode.NoError + ), + questionCount = names.size, + questionWriter = { w, i -> + w.write( + DnsQuestion( + name = names[i], + type = QuestionType.PTR.value.toInt(), + clazz = QuestionClass.IN.value.toInt(), + queryUnicast = false + ) + ) + } + ) + + send(writer.toByteArray()) + } + + private fun send(data: ByteArray) { + _lockObject.withLock { + for (sender in _senders) { + try { + val endPoint = if (sender.localAddress is Inet4Address) MdnsEndpointIPv4 else MdnsEndpointIPv6 + sender.send(DatagramPacket(data, data.size, endPoint)) + } catch (e: Exception) { + Logger.i(TAG, "Failed to send on ${sender.localSocketAddress}: ${e.message}.") + } + } + } + } + + fun queryAllQuestions(names: Array<String>) { + if (names.isEmpty()) throw IllegalArgumentException("At least one name must be specified.") + + val questions = names.flatMap { _serviceRecordAggregator.getAllQuestions(it) } + questions.groupBy { it.name }.forEach { (_, questionsForHost) -> + val writer = DnsWriter() + writer.writePacket( + DnsPacketHeader( + identifier = 0u, + queryResponse = QueryResponse.Query.value.toInt(), + opcode = DnsOpcode.StandardQuery.value.toInt(), + truncated = false, + nonAuthenticatedData = false, + recursionDesired = false, + answerAuthenticated = false, + authoritativeAnswer = false, + recursionAvailable = false, + responseCode = DnsResponseCode.NoError + ), + questionCount = questionsForHost.size, + questionWriter = { w, i -> w.write(questionsForHost[i]) } + ) + send(writer.toByteArray()) + } + } + + private fun onNicsAdded(nics: List<NetworkInterface>) { + _lockObject.withLock { + if (!_started) return + + val addresses = nics.flatMap { nic -> + nic.interfaceAddresses.map { it.address } + .filter { it is Inet4Address || (it is Inet6Address && it.isLinkLocalAddress) } + } + + addresses.forEach { address -> + Logger.i(TAG, "New address discovered $address") + + try { + when (address) { + is Inet4Address -> { + val sender = MulticastSocket(null).apply { + reuseAddress = true + bind(InetSocketAddress(address, MulticastPort)) + joinGroup(InetSocketAddress(MulticastAddressIPv4, MulticastPort), NetworkInterface.getByInetAddress(address)) + } + _senders.add(sender) + } + + is Inet6Address -> { + val sender = MulticastSocket(null).apply { + reuseAddress = true + bind(InetSocketAddress(address, MulticastPort)) + joinGroup(InetSocketAddress(MulticastAddressIPv6, MulticastPort), NetworkInterface.getByInetAddress(address)) + } + _senders.add(sender) + } + + else -> throw UnsupportedOperationException("Address type ${address.javaClass.name} is not supported.") + } + } catch (e: Exception) { + Logger.i(TAG, "Exception occurred when processing added address $address: ${e.message}.") + // Close the socket if there was an error + (_senders.lastOrNull() as? MulticastSocket)?.close() + } + } + } + + if (nics.isNotEmpty()) { + try { + updateBroadcastRecords() + broadcastRecords() + } catch (e: Exception) { + Logger.i(TAG, "Exception occurred when broadcasting records: ${e.message}.") + } + } + } + + private fun onNicsRemoved(nics: List<NetworkInterface>) { + _lockObject.withLock { + if (!_started) return + //TODO: Cleanup? + } + + if (nics.isNotEmpty()) { + try { + updateBroadcastRecords() + broadcastRecords() + } catch (e: Exception) { + Logger.e(TAG, "Exception occurred when broadcasting records", e) + } + } + } + + private fun receiveLoop(client: DatagramSocket) { + Logger.i(TAG, "Started receive loop") + + val buffer = ByteArray(1024) + val packet = DatagramPacket(buffer, buffer.size) + while (_started) { + try { + client.receive(packet) + handleResult(packet) + } catch (e: Exception) { + Logger.e(TAG, "An exception occurred while handling UDP result:", e) + } + } + + Logger.i(TAG, "Stopped receive loop") + } + + fun broadcastService( + deviceName: String, + serviceName: String, + port: UShort, + ttl: UInt = 120u, + weight: UShort = 0u, + priority: UShort = 0u, + texts: List<String>? = null + ) { + _recordLockObject.withLock { + _services.add( + BroadcastService( + deviceName = deviceName, + port = port, + priority = priority, + serviceName = serviceName, + texts = texts, + ttl = ttl, + weight = weight + ) + ) + } + + updateBroadcastRecords() + broadcastRecords() + } + + private fun updateBroadcastRecords() { + _recordLockObject.withLock { + _recordsSRV.clear() + _recordsPTR.clear() + _recordsA.clear() + _recordsAAAA.clear() + _recordsTXT.clear() + + _services.forEach { service -> + val id = UUID.randomUUID().toString() + val deviceDomainName = "${service.deviceName}.${service.serviceName}" + val addressName = "$id.local" + + _recordsSRV.add( + DnsResourceRecord( + clazz = ResourceRecordClass.IN.value.toInt(), + type = ResourceRecordType.SRV.value.toInt(), + timeToLive = service.ttl, + name = deviceDomainName, + cacheFlush = false + ) to SRVRecord( + target = addressName, + port = service.port, + priority = service.priority, + weight = service.weight + ) + ) + + _recordsPTR.add( + DnsResourceRecord( + clazz = ResourceRecordClass.IN.value.toInt(), + type = ResourceRecordType.PTR.value.toInt(), + timeToLive = service.ttl, + name = service.serviceName, + cacheFlush = false + ) to PTRRecord( + domainName = deviceDomainName + ) + ) + + val addresses = _nicMonitor.current.flatMap { nic -> + nic.interfaceAddresses.map { it.address } + } + + addresses.forEach { address -> + when (address) { + is Inet4Address -> _recordsA.add( + DnsResourceRecord( + clazz = ResourceRecordClass.IN.value.toInt(), + type = ResourceRecordType.A.value.toInt(), + timeToLive = service.ttl, + name = addressName, + cacheFlush = false + ) to ARecord( + address = address + ) + ) + + is Inet6Address -> _recordsAAAA.add( + DnsResourceRecord( + clazz = ResourceRecordClass.IN.value.toInt(), + type = ResourceRecordType.AAAA.value.toInt(), + timeToLive = service.ttl, + name = addressName, + cacheFlush = false + ) to AAAARecord( + address = address + ) + ) + + else -> Logger.i(TAG, "Invalid address type: $address.") + } + } + + if (service.texts != null) { + _recordsTXT.add( + DnsResourceRecord( + clazz = ResourceRecordClass.IN.value.toInt(), + type = ResourceRecordType.TXT.value.toInt(), + timeToLive = service.ttl, + name = deviceDomainName, + cacheFlush = false + ) to TXTRecord( + texts = service.texts + ) + ) + } + } + } + } + + private fun broadcastRecords(questions: List<DnsQuestion>? = null) { + val writer = DnsWriter() + _recordLockObject.withLock { + val recordsA: List<Pair<DnsResourceRecord, ARecord>> + val recordsAAAA: List<Pair<DnsResourceRecord, AAAARecord>> + val recordsPTR: List<Pair<DnsResourceRecord, PTRRecord>> + val recordsTXT: List<Pair<DnsResourceRecord, TXTRecord>> + val recordsSRV: List<Pair<DnsResourceRecord, SRVRecord>> + + if (questions != null) { + recordsA = _recordsA.filter { r -> questions.any { q -> q.name == r.first.name && q.clazz == r.first.clazz && q.type == r.first.type } } + recordsAAAA = _recordsAAAA.filter { r -> questions.any { q -> q.name == r.first.name && q.clazz == r.first.clazz && q.type == r.first.type } } + recordsPTR = _recordsPTR.filter { r -> questions.any { q -> q.name == r.first.name && q.clazz == r.first.clazz && q.type == r.first.type } } + recordsSRV = _recordsSRV.filter { r -> questions.any { q -> q.name == r.first.name && q.clazz == r.first.clazz && q.type == r.first.type } } + recordsTXT = _recordsTXT.filter { r -> questions.any { q -> q.name == r.first.name && q.clazz == r.first.clazz && q.type == r.first.type } } + } else { + recordsA = _recordsA + recordsAAAA = _recordsAAAA + recordsPTR = _recordsPTR + recordsSRV = _recordsSRV + recordsTXT = _recordsTXT + } + + val answerCount = recordsA.size + recordsAAAA.size + recordsPTR.size + recordsSRV.size + recordsTXT.size + if (answerCount < 1) return + + val txtOffset = recordsA.size + recordsAAAA.size + recordsPTR.size + recordsSRV.size + val srvOffset = recordsA.size + recordsAAAA.size + recordsPTR.size + val ptrOffset = recordsA.size + recordsAAAA.size + val aaaaOffset = recordsA.size + + writer.writePacket( + DnsPacketHeader( + identifier = 0u, + queryResponse = QueryResponse.Response.value.toInt(), + opcode = DnsOpcode.StandardQuery.value.toInt(), + truncated = false, + nonAuthenticatedData = false, + recursionDesired = false, + answerAuthenticated = false, + authoritativeAnswer = true, + recursionAvailable = false, + responseCode = DnsResponseCode.NoError + ), + answerCount = answerCount, + answerWriter = { w, i -> + when { + i >= txtOffset -> { + val record = recordsTXT[i - txtOffset] + w.write(record.first) { it.write(record.second) } + } + + i >= srvOffset -> { + val record = recordsSRV[i - srvOffset] + w.write(record.first) { it.write(record.second) } + } + + i >= ptrOffset -> { + val record = recordsPTR[i - ptrOffset] + w.write(record.first) { it.write(record.second) } + } + + i >= aaaaOffset -> { + val record = recordsAAAA[i - aaaaOffset] + w.write(record.first) { it.write(record.second) } + } + + else -> { + val record = recordsA[i] + w.write(record.first) { it.write(record.second) } + } + } + } + ) + } + + send(writer.toByteArray()) + } + + private fun handleResult(result: DatagramPacket) { + try { + val packet = DnsPacket.parse(result.data) + if (packet.questions.isNotEmpty()) { + _scope?.launch(Dispatchers.IO) { + try { + broadcastRecords(packet.questions) + } catch (e: Throwable) { + Logger.i(TAG, "Broadcasting records failed", e) + } + } + + } + _serviceRecordAggregator.add(packet) + onPacket?.invoke(packet) + } catch (e: Exception) { + Logger.v(TAG, "Failed to handle packet: ${Base64.getEncoder().encodeToString(result.data.slice(IntRange(0, result.length - 1)).toByteArray())}", e) + } + } + + fun stop() { + _lockObject.withLock { + _started = false + + _scope?.cancel() + _scope = null + + _nicMonitor.stop() + _serviceRecordAggregator.stop() + + _receiver4?.close() + _receiver4 = null + + _receiver6?.close() + _receiver6 = null + + _senders.forEach { it.close() } + _senders.clear() + } + + _threadReceiver4?.join() + _threadReceiver4 = null + + _threadReceiver6?.join() + _threadReceiver6 = null + } +} diff --git a/app/src/main/java/com/futo/platformplayer/mdns/NICMonitor.kt b/app/src/main/java/com/futo/platformplayer/mdns/NICMonitor.kt new file mode 100644 index 00000000..884e1514 --- /dev/null +++ b/app/src/main/java/com/futo/platformplayer/mdns/NICMonitor.kt @@ -0,0 +1,66 @@ +package com.futo.platformplayer.mdns + +import kotlinx.coroutines.* +import java.net.NetworkInterface + +class NICMonitor { + private val lockObject = Any() + private val nics = mutableListOf<NetworkInterface>() + private var cts: Job? = null + + val current: List<NetworkInterface> + get() = synchronized(nics) { nics.toList() } + + var added: ((List<NetworkInterface>) -> Unit)? = null + var removed: ((List<NetworkInterface>) -> Unit)? = null + + fun start() { + synchronized(lockObject) { + if (cts != null) throw Exception("Already started.") + + cts = CoroutineScope(Dispatchers.Default).launch { + loopAsync() + } + } + + nics.clear() + nics.addAll(getCurrentInterfaces().toList()) + } + + fun stop() { + synchronized(lockObject) { + cts?.cancel() + cts = null + } + + synchronized(nics) { + nics.clear() + } + } + + private suspend fun loopAsync() { + while (cts?.isActive == true) { + try { + val currentNics = getCurrentInterfaces().toList() + removed?.invoke(nics.filter { k -> currentNics.none { n -> k.name == n.name } }) + added?.invoke(currentNics.filter { nic -> nics.none { k -> k.name == nic.name } }) + + synchronized(nics) { + nics.clear() + nics.addAll(currentNics) + } + } catch (ex: Exception) { + // Ignored + } + delay(5000) + } + } + + private fun getCurrentInterfaces(): List<NetworkInterface> { + val nics = NetworkInterface.getNetworkInterfaces().toList() + .filter { it.isUp && !it.isLoopback } + + return if (nics.isNotEmpty()) nics else NetworkInterface.getNetworkInterfaces().toList() + .filter { it.isUp } + } +} diff --git a/app/src/main/java/com/futo/platformplayer/mdns/ServiceDiscoverer.kt b/app/src/main/java/com/futo/platformplayer/mdns/ServiceDiscoverer.kt new file mode 100644 index 00000000..79d29736 --- /dev/null +++ b/app/src/main/java/com/futo/platformplayer/mdns/ServiceDiscoverer.kt @@ -0,0 +1,68 @@ +package com.futo.platformplayer.mdns + +import com.futo.platformplayer.logging.Logger +import java.lang.Thread.sleep + +class ServiceDiscoverer(names: Array<String>, private val _onServicesUpdated: (List<DnsService>) -> Unit) { + private val _names: Array<String> + private var _listener: MDNSListener? = null + private var _started = false + private var _thread: Thread? = null + + init { + if (names.isEmpty()) throw IllegalArgumentException("At least one name must be specified.") + _names = names + } + + fun broadcastService( + deviceName: String, + serviceName: String, + port: UShort, + ttl: UInt = 120u, + weight: UShort = 0u, + priority: UShort = 0u, + texts: List<String>? = null + ) { + _listener?.let { + it.broadcastService(deviceName, serviceName, port, ttl, weight, priority, texts) + } + } + + fun stop() { + _started = false + _listener?.stop() + _listener = null + _thread?.join() + _thread = null + } + + fun start() { + if (_started) throw Exception("Already running.") + _started = true + + val listener = MDNSListener() + _listener = listener + listener.onServicesUpdated = { _onServicesUpdated?.invoke(it) } + listener.start() + + _thread = Thread { + try { + sleep(2000) + + while (_started) { + listener.queryServices(_names) + sleep(2000) + listener.queryAllQuestions(_names) + sleep(2000) + } + } catch (e: Throwable) { + Logger.i(TAG, "Exception in loop thread", e) + stop() + } + }.apply { start() } + } + + companion object { + private val TAG = "ServiceDiscoverer" + } +} diff --git a/app/src/main/java/com/futo/platformplayer/mdns/ServiceRecordAggregator.kt b/app/src/main/java/com/futo/platformplayer/mdns/ServiceRecordAggregator.kt new file mode 100644 index 00000000..5ff21a2c --- /dev/null +++ b/app/src/main/java/com/futo/platformplayer/mdns/ServiceRecordAggregator.kt @@ -0,0 +1,219 @@ +package com.futo.platformplayer.mdns + +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.delay +import kotlinx.coroutines.isActive +import kotlinx.coroutines.launch +import java.net.InetAddress +import java.util.Date + +data class DnsService( + var name: String, + var target: String, + var port: UShort, + val addresses: MutableList<InetAddress> = mutableListOf(), + val pointers: MutableList<String> = mutableListOf(), + val texts: MutableList<String> = mutableListOf() +) + +data class CachedDnsAddressRecord( + val expirationTime: Date, + val address: InetAddress +) + +data class CachedDnsTxtRecord( + val expirationTime: Date, + val texts: List<String> +) + +data class CachedDnsPtrRecord( + val expirationTime: Date, + val target: String +) + +data class CachedDnsSrvRecord( + val expirationTime: Date, + val service: SRVRecord +) + +class ServiceRecordAggregator { + private val _lockObject = Any() + private val _cachedAddressRecords = mutableMapOf<String, MutableList<CachedDnsAddressRecord>>() + private val _cachedTxtRecords = mutableMapOf<String, CachedDnsTxtRecord>() + private val _cachedPtrRecords = mutableMapOf<String, MutableList<CachedDnsPtrRecord>>() + private val _cachedSrvRecords = mutableMapOf<String, CachedDnsSrvRecord>() + private val _currentServices = mutableListOf<DnsService>() + private var _cts: Job? = null + + var onServicesUpdated: ((List<DnsService>) -> Unit)? = null + + fun start() { + synchronized(_lockObject) { + if (_cts != null) throw Exception("Already started.") + + _cts = CoroutineScope(Dispatchers.Default).launch { + while (isActive) { + val now = Date() + synchronized(_currentServices) { + _cachedAddressRecords.forEach { it.value.removeAll { record -> now.after(record.expirationTime) } } + _cachedTxtRecords.entries.removeIf { now.after(it.value.expirationTime) } + _cachedSrvRecords.entries.removeIf { now.after(it.value.expirationTime) } + _cachedPtrRecords.forEach { it.value.removeAll { record -> now.after(record.expirationTime) } } + + val newServices = getCurrentServices() + _currentServices.clear() + _currentServices.addAll(newServices) + } + + onServicesUpdated?.invoke(_currentServices) + delay(5000) + } + } + } + } + + fun stop() { + synchronized(_lockObject) { + _cts?.cancel() + _cts = null + } + } + + fun add(packet: DnsPacket) { + val dnsResourceRecords = packet.answers + packet.additionals + packet.authorities + val txtRecords = dnsResourceRecords.filter { it.type == ResourceRecordType.TXT.value.toInt() }.map { it to it.getDataReader().readTXTRecord() } + val aRecords = dnsResourceRecords.filter { it.type == ResourceRecordType.A.value.toInt() }.map { it to it.getDataReader().readARecord() } + val aaaaRecords = dnsResourceRecords.filter { it.type == ResourceRecordType.AAAA.value.toInt() }.map { it to it.getDataReader().readAAAARecord() } + val srvRecords = dnsResourceRecords.filter { it.type == ResourceRecordType.SRV.value.toInt() }.map { it to it.getDataReader().readSRVRecord() } + val ptrRecords = dnsResourceRecords.filter { it.type == ResourceRecordType.PTR.value.toInt() }.map { it to it.getDataReader().readPTRRecord() } + + /*val builder = StringBuilder() + builder.appendLine("Received records:") + srvRecords.forEach { builder.appendLine(" ${it.first.name} ${it.first.type} ${it.first.clazz} TTL ${it.first.timeToLive}: (Port: ${it.second.port}, Target: ${it.second.target}, Priority: ${it.second.priority}, Weight: ${it.second.weight})") } + ptrRecords.forEach { builder.appendLine(" ${it.first.name} ${it.first.type} ${it.first.clazz} TTL ${it.first.timeToLive}: ${it.second.domainName}") } + txtRecords.forEach { builder.appendLine(" ${it.first.name} ${it.first.type} ${it.first.clazz} TTL ${it.first.timeToLive}: ${it.second.texts.joinToString(", ")}") } + aRecords.forEach { builder.appendLine(" ${it.first.name} ${it.first.type} ${it.first.clazz} TTL ${it.first.timeToLive}: ${it.second.address}") } + aaaaRecords.forEach { builder.appendLine(" ${it.first.name} ${it.first.type} ${it.first.clazz} TTL ${it.first.timeToLive}: ${it.second.address}") } + synchronized(lockObject) { + // Save to file if necessary + }*/ + + val currentServices: MutableList<DnsService> + synchronized(this._currentServices) { + ptrRecords.forEach { record -> + val cachedPtrRecord = _cachedPtrRecords.getOrPut(record.first.name) { mutableListOf() } + val newPtrRecord = CachedDnsPtrRecord(Date(System.currentTimeMillis() + record.first.timeToLive.toLong() * 1000L), record.second.domainName) + cachedPtrRecord.replaceOrAdd(newPtrRecord) { it.target == record.second.domainName } + + aRecords.forEach { aRecord -> + val cachedARecord = _cachedAddressRecords.getOrPut(aRecord.first.name) { mutableListOf() } + val newARecord = CachedDnsAddressRecord(Date(System.currentTimeMillis() + aRecord.first.timeToLive.toLong() * 1000L), aRecord.second.address) + cachedARecord.replaceOrAdd(newARecord) { it.address == newARecord.address } + } + + aaaaRecords.forEach { aaaaRecord -> + val cachedAaaaRecord = _cachedAddressRecords.getOrPut(aaaaRecord.first.name) { mutableListOf() } + val newAaaaRecord = CachedDnsAddressRecord(Date(System.currentTimeMillis() + aaaaRecord.first.timeToLive.toLong() * 1000L), aaaaRecord.second.address) + cachedAaaaRecord.replaceOrAdd(newAaaaRecord) { it.address == newAaaaRecord.address } + } + } + + txtRecords.forEach { txtRecord -> + _cachedTxtRecords[txtRecord.first.name] = CachedDnsTxtRecord(Date(System.currentTimeMillis() + txtRecord.first.timeToLive.toLong() * 1000L), txtRecord.second.texts) + } + + srvRecords.forEach { srvRecord -> + _cachedSrvRecords[srvRecord.first.name] = CachedDnsSrvRecord(Date(System.currentTimeMillis() + srvRecord.first.timeToLive.toLong() * 1000L), srvRecord.second) + } + + currentServices = getCurrentServices() + this._currentServices.clear() + this._currentServices.addAll(currentServices) + } + + onServicesUpdated?.invoke(currentServices) + } + + fun getAllQuestions(serviceName: String): List<DnsQuestion> { + val questions = mutableListOf<DnsQuestion>() + synchronized(_currentServices) { + val servicePtrRecords = _cachedPtrRecords[serviceName] ?: return emptyList() + + val ptrWithoutSrvRecord = servicePtrRecords.filterNot { _cachedSrvRecords.containsKey(it.target) }.map { it.target } + questions.addAll(ptrWithoutSrvRecord.flatMap { s -> + listOf( + DnsQuestion( + name = s, + type = QuestionType.SRV.value.toInt(), + clazz = QuestionClass.IN.value.toInt(), + queryUnicast = false + ) + ) + }) + + val incompleteCurrentServices = _currentServices.filter { it.addresses.isEmpty() && it.name.endsWith(serviceName) } + questions.addAll(incompleteCurrentServices.flatMap { s -> + listOf( + DnsQuestion( + name = s.name, + type = QuestionType.TXT.value.toInt(), + clazz = QuestionClass.IN.value.toInt(), + queryUnicast = false + ), + DnsQuestion( + name = s.target, + type = QuestionType.A.value.toInt(), + clazz = QuestionClass.IN.value.toInt(), + queryUnicast = false + ), + DnsQuestion( + name = s.target, + type = QuestionType.AAAA.value.toInt(), + clazz = QuestionClass.IN.value.toInt(), + queryUnicast = false + ) + ) + }) + } + return questions + } + + private fun getCurrentServices(): MutableList<DnsService> { + val currentServices = _cachedSrvRecords.map { (key, value) -> + DnsService( + name = key, + target = value.service.target, + port = value.service.port + ) + }.toMutableList() + + currentServices.forEach { service -> + _cachedAddressRecords[service.target]?.let { + service.addresses.addAll(it.map { record -> record.address }) + } + } + + currentServices.forEach { service -> + service.pointers.addAll(_cachedPtrRecords.filter { it.value.any { ptr -> ptr.target == service.name } }.map { it.key }) + } + + currentServices.forEach { service -> + _cachedTxtRecords[service.name]?.let { + service.texts.addAll(it.texts) + } + } + + return currentServices + } + + private inline fun <T> MutableList<T>.replaceOrAdd(newElement: T, predicate: (T) -> Boolean) { + val index = indexOfFirst(predicate) + if (index >= 0) { + this[index] = newElement + } else { + add(newElement) + } + } +} diff --git a/app/src/main/java/com/futo/platformplayer/views/FeedStyle.kt b/app/src/main/java/com/futo/platformplayer/views/FeedStyle.kt index c261b755..9750f382 100644 --- a/app/src/main/java/com/futo/platformplayer/views/FeedStyle.kt +++ b/app/src/main/java/com/futo/platformplayer/views/FeedStyle.kt @@ -16,7 +16,7 @@ enum class FeedStyle(val value: Int) { fun fromInt(value: Int): FeedStyle { - val result = FeedStyle.values().firstOrNull { it.value == value }; + val result = FeedStyle.entries.firstOrNull { it.value == value }; if(result == null) throw UnknownPlatformException(value.toString()); return result; diff --git a/app/src/main/res/drawable/battery_full_24px.xml b/app/src/main/res/drawable/battery_full_24px.xml new file mode 100644 index 00000000..af90cb27 --- /dev/null +++ b/app/src/main/res/drawable/battery_full_24px.xml @@ -0,0 +1,9 @@ +<vector xmlns:android="http://schemas.android.com/apk/res/android" + android:width="24dp" + android:height="24dp" + android:viewportWidth="960" + android:viewportHeight="960"> + <path + android:fillColor="@android:color/white" + android:pathData="M320,880Q303,880 291.5,868.5Q280,857 280,840L280,200Q280,183 291.5,171.5Q303,160 320,160L400,160L400,80L560,80L560,160L640,160Q657,160 668.5,171.5Q680,183 680,200L680,840Q680,857 668.5,868.5Q657,880 640,880L320,880Z"/> +</vector> diff --git a/app/src/test/java/com/futo/platformplayer/MdnsTests.kt b/app/src/test/java/com/futo/platformplayer/MdnsTests.kt new file mode 100644 index 00000000..64a37d6e --- /dev/null +++ b/app/src/test/java/com/futo/platformplayer/MdnsTests.kt @@ -0,0 +1,394 @@ +package com.futo.platformplayer + +import com.futo.platformplayer.mdns.DnsOpcode +import com.futo.platformplayer.mdns.DnsPacket +import com.futo.platformplayer.mdns.DnsPacketHeader +import com.futo.platformplayer.mdns.DnsQuestion +import com.futo.platformplayer.mdns.DnsReader +import com.futo.platformplayer.mdns.DnsResponseCode +import com.futo.platformplayer.mdns.DnsWriter +import com.futo.platformplayer.mdns.QueryResponse +import com.futo.platformplayer.mdns.QuestionClass +import com.futo.platformplayer.mdns.QuestionType +import com.futo.platformplayer.mdns.ResourceRecordClass +import com.futo.platformplayer.mdns.ResourceRecordType +import junit.framework.TestCase.assertEquals +import junit.framework.TestCase.assertTrue +import java.io.ByteArrayOutputStream +import java.net.InetAddress +import kotlin.test.Test +import kotlin.test.assertContentEquals + + +class MdnsTests { + + @Test + fun `BasicOperation`() { + val expectedData = byteArrayOf( + 0x00, 0x01, + 0x00, 0x00, 0x00, 0x02, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, + 0x00, 0x01, + 0x00, 0x00, 0x00, 0x02, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03 + ) + + val writer = DnsWriter() + writer.write(1.toUShort()) + writer.write(2.toUInt()) + writer.write(3.toULong()) + writer.write(1.toShort()) + writer.write(2) + writer.write(3L) + + assertContentEquals(expectedData, writer.toByteArray()) + } + + @Test + fun `DnsQuestionFormat`() { + val expectedBytes = ubyteArrayOf( + 0x00u, 0x00u, 0x00u, 0x00u, 0x00u, 0x01u, 0x00u, 0x00u, 0x00u, 0x00u, 0x00u, 0x00u, 0x08u, 0x5fu, 0x61u, 0x69u, 0x72u, 0x70u, 0x6cu, 0x61u, 0x79u, 0x04u, 0x5fu, 0x74u, 0x63u, 0x70u, 0x05u, 0x6cu, 0x6fu, 0x63u, 0x61u, 0x6cu, 0x00u, 0x00u, 0x0cu, 0x00u, 0x01u + ).asByteArray() + + val writer = DnsWriter() + writer.writePacket( + header = DnsPacketHeader( + identifier = 0.toUShort(), + queryResponse = QueryResponse.Query.value.toInt(), + opcode = DnsOpcode.StandardQuery.value.toInt(), + authoritativeAnswer = false, + truncated = false, + recursionDesired = false, + recursionAvailable = false, + answerAuthenticated = false, + nonAuthenticatedData = false, + responseCode = DnsResponseCode.NoError + ), + questionCount = 1, + questionWriter = { w, _ -> + w.write(DnsQuestion( + name = "_airplay._tcp.local", + type = QuestionType.PTR.value.toInt(), + clazz = QuestionClass.IN.value.toInt(), + queryUnicast = false + )) + } + ) + + assertContentEquals(expectedBytes, writer.toByteArray()) + } + + @Test + fun `BeyondTests`() { + val data = byteArrayOf( + 0x00, 0x01, + 0x00, 0x00, 0x00, 0x02, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, + 0x00, 0x01, + 0x00, 0x00, 0x00, 0x02, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03 + ) + + val reader = DnsReader(data) + assertEquals(1, reader.readInt16()) + assertEquals(2, reader.readInt32()) + assertEquals(3L, reader.readInt64()) + assertEquals(1.toUShort(), reader.readUInt16()) + assertEquals(2.toUInt(), reader.readUInt32()) + assertEquals(3.toULong(), reader.readUInt64()) + } + + @Test + fun `ParseDnsPrinter`() { + val data = ubyteArrayOf( + 0x00u, 0x00u, + 0x84u, 0x00u, 0x00u, 0x00u, 0x00u, 0x01u, 0x00u, 0x00u, 0x00u, 0x06u, 0x04u, 0x5fu, 0x69u, 0x70u, 0x70u, 0x04u, + 0x5fu, 0x74u, 0x63u, 0x70u, 0x05u, 0x6cu, 0x6fu, 0x63u, 0x61u, 0x6cu, 0x00u, 0x00u, 0x0cu, 0x00u, 0x01u, 0x00u, + 0x00u, 0x11u, 0x94u, 0x00u, 0x1eu, 0x1bu, 0x42u, 0x72u, 0x6fu, 0x74u, 0x68u, 0x65u, 0x72u, 0x20u, 0x44u, 0x43u, + 0x50u, 0x2du, 0x4cu, 0x33u, 0x35u, 0x35u, 0x30u, 0x43u, 0x44u, 0x57u, 0x20u, 0x73u, 0x65u, 0x72u, 0x69u, 0x65u, + 0x73u, 0xc0u, 0x0cu, 0xc0u, 0x27u, 0x00u, 0x10u, 0x80u, 0x01u, 0x00u, 0x00u, 0x11u, 0x94u, 0x02u, 0x53u, 0x09u, + 0x74u, 0x78u, 0x74u, 0x76u, 0x65u, 0x72u, 0x73u, 0x3du, 0x31u, 0x08u, 0x71u, 0x74u, 0x6fu, 0x74u, 0x61u, 0x6cu, + 0x3du, 0x31u, 0x42u, 0x70u, 0x64u, 0x6cu, 0x3du, 0x61u, 0x70u, 0x70u, 0x6cu, 0x69u, 0x63u, 0x61u, 0x74u, 0x69u, + 0x6fu, 0x6eu, 0x2fu, 0x6fu, 0x63u, 0x74u, 0x65u, 0x74u, 0x2du, 0x73u, 0x74u, 0x72u, 0x65u, 0x61u, 0x6du, 0x2cu, + 0x69u, 0x6du, 0x61u, 0x67u, 0x65u, 0x2fu, 0x75u, 0x72u, 0x66u, 0x2cu, 0x69u, 0x6du, 0x61u, 0x67u, 0x65u, 0x2fu, + 0x6au, 0x70u, 0x65u, 0x67u, 0x2cu, 0x69u, 0x6du, 0x61u, 0x67u, 0x65u, 0x2fu, 0x70u, 0x77u, 0x67u, 0x2du, 0x72u, + 0x61u, 0x73u, 0x74u, 0x65u, 0x72u, 0x0cu, 0x72u, 0x70u, 0x3du, 0x69u, 0x70u, 0x70u, 0x2fu, 0x70u, 0x72u, 0x69u, + 0x6eu, 0x74u, 0x05u, 0x6eu, 0x6fu, 0x74u, 0x65u, 0x3du, 0x1eu, 0x74u, 0x79u, 0x3du, 0x42u, 0x72u, 0x6fu, 0x74u, + 0x68u, 0x65u, 0x72u, 0x20u, 0x44u, 0x43u, 0x50u, 0x2du, 0x4cu, 0x33u, 0x35u, 0x35u, 0x30u, 0x43u, 0x44u, 0x57u, + 0x20u, 0x73u, 0x65u, 0x72u, 0x69u, 0x65u, 0x73u, 0x25u, 0x70u, 0x72u, 0x6fu, 0x64u, 0x75u, 0x63u, 0x74u, 0x3du, + 0x28u, 0x42u, 0x72u, 0x6fu, 0x74u, 0x68u, 0x65u, 0x72u, 0x20u, 0x44u, 0x43u, 0x50u, 0x2du, 0x4cu, 0x33u, 0x35u, + 0x35u, 0x30u, 0x43u, 0x44u, 0x57u, 0x20u, 0x73u, 0x65u, 0x72u, 0x69u, 0x65u, 0x73u, 0x29u, 0x3cu, 0x61u, 0x64u, + 0x6du, 0x69u, 0x6eu, 0x75u, 0x72u, 0x6cu, 0x3du, 0x68u, 0x74u, 0x74u, 0x70u, 0x3au, 0x2fu, 0x2fu, 0x42u, 0x52u, + 0x57u, 0x31u, 0x30u, 0x35u, 0x42u, 0x41u, 0x44u, 0x34u, 0x41u, 0x31u, 0x35u, 0x37u, 0x30u, 0x2eu, 0x6cu, 0x6fu, + 0x63u, 0x61u, 0x6cu, 0x2eu, 0x2fu, 0x6eu, 0x65u, 0x74u, 0x2fu, 0x6eu, 0x65u, 0x74u, 0x2fu, 0x61u, 0x69u, 0x72u, + 0x70u, 0x72u, 0x69u, 0x6eu, 0x74u, 0x2eu, 0x68u, 0x74u, 0x6du, 0x6cu, 0x0bu, 0x70u, 0x72u, 0x69u, 0x6fu, 0x72u, + 0x69u, 0x74u, 0x79u, 0x3du, 0x32u, 0x35u, 0x0fu, 0x75u, 0x73u, 0x62u, 0x5fu, 0x4du, 0x46u, 0x47u, 0x3du, 0x42u, + 0x72u, 0x6fu, 0x74u, 0x68u, 0x65u, 0x72u, 0x1bu, 0x75u, 0x73u, 0x62u, 0x5fu, 0x4du, 0x44u, 0x4cu, 0x3du, 0x44u, + 0x43u, 0x50u, 0x2du, 0x4cu, 0x33u, 0x35u, 0x35u, 0x30u, 0x43u, 0x44u, 0x57u, 0x20u, 0x73u, 0x65u, 0x72u, 0x69u, + 0x65u, 0x73u, 0x19u, 0x75u, 0x73u, 0x62u, 0x5fu, 0x43u, 0x4du, 0x44u, 0x3du, 0x50u, 0x4au, 0x4cu, 0x2cu, 0x50u, + 0x43u, 0x4cu, 0x2cu, 0x50u, 0x43u, 0x4cu, 0x58u, 0x4cu, 0x2cu, 0x55u, 0x52u, 0x46u, 0x07u, 0x43u, 0x6fu, 0x6cu, + 0x6fu, 0x72u, 0x3du, 0x54u, 0x08u, 0x43u, 0x6fu, 0x70u, 0x69u, 0x65u, 0x73u, 0x3du, 0x54u, 0x08u, 0x44u, 0x75u, + 0x70u, 0x6cu, 0x65u, 0x78u, 0x3du, 0x54u, 0x05u, 0x46u, 0x61u, 0x78u, 0x3du, 0x46u, 0x06u, 0x53u, 0x63u, 0x61u, + 0x6eu, 0x3du, 0x54u, 0x0du, 0x50u, 0x61u, 0x70u, 0x65u, 0x72u, 0x43u, 0x75u, 0x73u, 0x74u, 0x6fu, 0x6du, 0x3du, + 0x54u, 0x08u, 0x42u, 0x69u, 0x6eu, 0x61u, 0x72u, 0x79u, 0x3du, 0x54u, 0x0du, 0x54u, 0x72u, 0x61u, 0x6eu, 0x73u, + 0x70u, 0x61u, 0x72u, 0x65u, 0x6eu, 0x74u, 0x3du, 0x54u, 0x06u, 0x54u, 0x42u, 0x43u, 0x50u, 0x3du, 0x46u, 0x3eu, + 0x55u, 0x52u, 0x46u, 0x3du, 0x53u, 0x52u, 0x47u, 0x42u, 0x32u, 0x34u, 0x2cu, 0x57u, 0x38u, 0x2cu, 0x43u, 0x50u, + 0x31u, 0x2cu, 0x49u, 0x53u, 0x34u, 0x2du, 0x31u, 0x2cu, 0x4du, 0x54u, 0x31u, 0x2du, 0x33u, 0x2du, 0x34u, 0x2du, + 0x35u, 0x2du, 0x38u, 0x2du, 0x31u, 0x31u, 0x2cu, 0x4fu, 0x42u, 0x31u, 0x30u, 0x2cu, 0x50u, 0x51u, 0x34u, 0x2cu, + 0x52u, 0x53u, 0x36u, 0x30u, 0x30u, 0x2cu, 0x56u, 0x31u, 0x2eu, 0x34u, 0x2cu, 0x44u, 0x4du, 0x31u, 0x25u, 0x6bu, + 0x69u, 0x6eu, 0x64u, 0x3du, 0x64u, 0x6fu, 0x63u, 0x75u, 0x6du, 0x65u, 0x6eu, 0x74u, 0x2cu, 0x65u, 0x6eu, 0x76u, + 0x65u, 0x6cu, 0x6fu, 0x70u, 0x65u, 0x2cu, 0x6cu, 0x61u, 0x62u, 0x65u, 0x6cu, 0x2cu, 0x70u, 0x6fu, 0x73u, 0x74u, + 0x63u, 0x61u, 0x72u, 0x64u, 0x11u, 0x50u, 0x61u, 0x70u, 0x65u, 0x72u, 0x4du, 0x61u, 0x78u, 0x3du, 0x6cu, 0x65u, + 0x67u, 0x61u, 0x6cu, 0x2du, 0x41u, 0x34u, 0x29u, 0x55u, 0x55u, 0x49u, 0x44u, 0x3du, 0x65u, 0x33u, 0x32u, 0x34u, + 0x38u, 0x30u, 0x30u, 0x30u, 0x2du, 0x38u, 0x30u, 0x63u, 0x65u, 0x2du, 0x31u, 0x31u, 0x64u, 0x62u, 0x2du, 0x38u, + 0x30u, 0x30u, 0x30u, 0x2du, 0x33u, 0x63u, 0x32u, 0x61u, 0x66u, 0x34u, 0x61u, 0x61u, 0x63u, 0x30u, 0x61u, 0x34u, + 0x0cu, 0x70u, 0x72u, 0x69u, 0x6eu, 0x74u, 0x5fu, 0x77u, 0x66u, 0x64u, 0x73u, 0x3du, 0x54u, 0x14u, 0x6du, 0x6fu, + 0x70u, 0x72u, 0x69u, 0x61u, 0x2du, 0x63u, 0x65u, 0x72u, 0x74u, 0x69u, 0x66u, 0x69u, 0x65u, 0x64u, 0x3du, 0x31u, + 0x2eu, 0x33u, 0x0fu, 0x42u, 0x52u, 0x57u, 0x31u, 0x30u, 0x35u, 0x42u, 0x41u, 0x44u, 0x34u, 0x41u, 0x31u, 0x35u, + 0x37u, 0x30u, 0xc0u, 0x16u, 0x00u, 0x01u, 0x80u, 0x01u, 0x00u, 0x00u, 0x00u, 0x78u, 0x00u, 0x04u, 0xc0u, 0xa8u, + 0x01u, 0xc5u, 0xc2u, 0xa4u, 0x00u, 0x1cu, 0x80u, 0x01u, 0x00u, 0x00u, 0x00u, 0x78u, 0x00u, 0x10u, 0xfeu, 0x80u, + 0x00u, 0x00u, 0x00u, 0x00u, 0x00u, 0x00u, 0x12u, 0x5bu, 0xadu, 0xffu, 0xfeu, 0x4au, 0x15u, 0x70u, 0xc0u, 0x27u, + 0x00u, 0x21u, 0x80u, 0x01u, 0x00u, 0x00u, 0x00u, 0x78u, 0x00u, 0x08u, 0x00u, 0x00u, 0x00u, 0x00u, 0x02u, 0x77u, + 0xc2u, 0xa4u, 0xc0u, 0x27u, 0x00u, 0x2fu, 0x80u, 0x01u, 0x00u, 0x00u, 0x11u, 0x94u, 0x00u, 0x09u, 0xc0u, 0x27u, + 0x00u, 0x05u, 0x00u, 0x00u, 0x80u, 0x00u, 0x40u, 0xc2u, 0xa4u, 0x00u, 0x2fu, 0x80u, 0x01u, 0x00u, 0x00u, 0x00u, + 0x78u, 0x00u, 0x08u, 0xc2u, 0xa4u, 0x00u, 0x04u, 0x40u, 0x00u, 0x00u, 0x08u + ) + + val packet = DnsPacket.parse(data.asByteArray()) + assertEquals(QueryResponse.Response.value.toInt(), packet.header.queryResponse) + assertEquals(DnsOpcode.StandardQuery.value.toInt(), packet.header.opcode) + assertTrue(packet.header.authoritativeAnswer) + assertEquals(false, packet.header.truncated) + assertEquals(false, packet.header.recursionDesired) + assertEquals(false, packet.header.recursionAvailable) + assertEquals(false, packet.header.answerAuthenticated) + assertEquals(false, packet.header.nonAuthenticatedData) + assertEquals(DnsResponseCode.NoError, packet.header.responseCode) + assertEquals(0, packet.questions.size) + assertEquals(1, packet.answers.size) + assertEquals(0, packet.authorities.size) + assertEquals(6, packet.additionals.size) + + val firstAnswer = packet.answers[0] + assertEquals("_ipp._tcp.local", firstAnswer.name) + assertEquals(ResourceRecordType.PTR.value.toInt(), firstAnswer.type) + assertEquals(ResourceRecordClass.IN.value.toInt(), firstAnswer.clazz) + assertEquals(false, firstAnswer.cacheFlush) + assertEquals(4500u, firstAnswer.timeToLive) + assertEquals(30, firstAnswer.dataLength) + assertEquals("Brother DCP-L3550CDW series._ipp._tcp.local", firstAnswer.getDataReader().readPTRRecord().domainName) + + val firstAdditional = packet.additionals[0] + assertEquals("Brother DCP-L3550CDW series._ipp._tcp.local", firstAdditional.name) + assertEquals(ResourceRecordType.TXT.value.toInt(), firstAdditional.type) + assertEquals(ResourceRecordClass.IN.value.toInt(), firstAdditional.clazz) + assertEquals(true, firstAdditional.cacheFlush) + assertEquals(4500u, firstAdditional.timeToLive) + assertEquals(595, firstAdditional.dataLength) + + val txtRecord = firstAdditional.getDataReader().readTXTRecord() + assertContentEquals(arrayOf( + "txtvers=1", + "qtotal=1", + "pdl=application/octet-stream,image/urf,image/jpeg,image/pwg-raster", + "rp=ipp/print", + "note=", + "ty=Brother DCP-L3550CDW series", + "product=(Brother DCP-L3550CDW series)", + "adminurl=http://BRW105BAD4A1570.local./net/net/airprint.html", + "priority=25", + "usb_MFG=Brother", + "usb_MDL=DCP-L3550CDW series", + "usb_CMD=PJL,PCL,PCLXL,URF", + "Color=T", + "Copies=T", + "Duplex=T", + "Fax=F", + "Scan=T", + "PaperCustom=T", + "Binary=T", + "Transparent=T", + "TBCP=F", + "URF=SRGB24,W8,CP1,IS4-1,MT1-3-4-5-8-11,OB10,PQ4,RS600,V1.4,DM1", + "kind=document,envelope,label,postcard", + "PaperMax=legal-A4", + "UUID=e3248000-80ce-11db-8000-3c2af4aac0a4", + "print_wfds=T", + "mopria-certified=1.3" + ), txtRecord.texts.toTypedArray()) + + val aRecord = packet.additionals[1].getDataReader().readARecord() + assertEquals(InetAddress.getByName("192.168.1.197"), aRecord.address) + + val aaaaRecord = packet.additionals[2].getDataReader().readAAAARecord() + assertEquals(InetAddress.getByName("fe80::125b:adff:fe4a:1570"), aaaaRecord.address) + + val srvRecord = packet.additionals[3].getDataReader().readSRVRecord() + assertEquals("BRW105BAD4A1570.local", srvRecord.target) + assertEquals(0, srvRecord.weight.toInt()) + assertEquals(0, srvRecord.priority.toInt()) + assertEquals(631, srvRecord.port.toInt()) + + val nSECRecord = packet.additionals[4].getDataReader().readNSECRecord() + assertEquals("Brother DCP-L3550CDW series._ipp._tcp.local", nSECRecord.ownerName) + assertEquals(1, nSECRecord.typeBitMaps.size) + assertEquals(0, nSECRecord.typeBitMaps[0].first) + assertContentEquals(byteArrayOf(0, 0, 128.toByte(), 0, 64), nSECRecord.typeBitMaps[0].second) + } + + @Test + fun `ParseSamsungTV`() { + val data = loadByteArray("samsung-airplay.hex") + val packet = DnsPacket.parse(data) + assertEquals(QueryResponse.Response.value.toInt(), packet.header.queryResponse) + assertEquals(DnsOpcode.StandardQuery.value.toInt(), packet.header.opcode) + assertTrue(packet.header.authoritativeAnswer) + assertEquals(false, packet.header.truncated) + assertEquals(false, packet.header.recursionDesired) + assertEquals(false, packet.header.recursionAvailable) + assertEquals(false, packet.header.answerAuthenticated) + assertEquals(false, packet.header.nonAuthenticatedData) + assertEquals(DnsResponseCode.NoError, packet.header.responseCode) + assertEquals(0, packet.questions.size) + assertEquals(6, packet.answers.size) + assertEquals(0, packet.authorities.size) + assertEquals(4, packet.additionals.size) + + assertEquals("9.1.168.192.in-addr.arpa", packet.answers[0].name) + assertEquals(ResourceRecordType.PTR.value.toInt(), packet.answers[0].type) + assertEquals(ResourceRecordClass.IN.value.toInt(), packet.answers[0].clazz) + assertTrue(packet.answers[0].cacheFlush) + assertEquals(120u, packet.answers[0].timeToLive) + assertEquals(15, packet.answers[0].dataLength) + assertEquals("Samsung.local", packet.answers[0].getDataReader().readPTRRecord().domainName) + + val txtRecord = packet.answers[1].getDataReader().readTXTRecord() + assertContentEquals(arrayOf( + "acl=0", + "deviceid=D4:9D:C0:2F:52:16", + "features=0x7F8AD0,0x38BCB46", + "rsf=0x3", + "fv=p20.0.1", + "flags=0x244", + "model=URU8000", + "manufacturer=Samsung", + "serialNumber=0EQC3HDM900064X", + "protovers=1.1", + "srcvers=377.17.24.6", + "pi=ED:0C:A5:ED:10:08", + "psi=00000000-0000-0000-0000-ED0CA5ED1008", + "gid=00000000-0000-0000-0000-ED0CA5ED1008", + "gcgl=0", + "pk=d25488cbff1334756165cd7229a235475ef591f2595f38ed251d46b8a4d2345d" + ), txtRecord.texts.toTypedArray()) + + val srvRecord = packet.answers[4].getDataReader().readSRVRecord() + assertEquals(33482, srvRecord.port.toInt()) + assertEquals(0, srvRecord.priority.toInt()) + assertEquals(0, srvRecord.weight.toInt()) + assertEquals("Samsung.local", srvRecord.target) + + val aRecord = packet.answers[5].getDataReader().readARecord() + assertEquals(InetAddress.getByName("192.168.1.9"), aRecord.address) + + val nSECRecord = packet.additionals[0].getDataReader().readNSECRecord() + assertEquals("9.1.168.192.in-addr.arpa", nSECRecord.ownerName) + assertEquals(1, nSECRecord.typeBitMaps.size) + assertEquals(0, nSECRecord.typeBitMaps[0].first) + assertContentEquals(byteArrayOf(0, 8), nSECRecord.typeBitMaps[0].second) + + val optRecord = packet.additionals[3].getDataReader().readOPTRecord() + assertEquals(1, optRecord.options.size) + assertEquals(65001, optRecord.options[0].code.toInt()) + assertEquals(5, optRecord.options[0].data.size) + assertContentEquals(byteArrayOf(0, 0, 116, 206.toByte(), 97), optRecord.options[0].data) + } + + @Test + fun `UnicodeTest`() { + val data = ubyteArrayOf( + 0x00u, 0x00u, 0x84u, 0x00u, 0x00u, 0x00u, 0x00u, 0x01u, 0x00u, 0x00u, 0x00u, 0x01u, 0x15u, 0x41u, 0x69u, 0x64u, + 0x61u, 0x6Eu, 0xE2u, 0x80u, 0x99u, 0x73u, 0x20u, 0x4Du, 0x61u, 0x63u, 0x42u, 0x6Fu, 0x6Fu, 0x6Bu, 0x20u, 0x50u, + 0x72u, 0x6Fu, 0x0Fu, 0x5Fu, 0x63u, 0x6Fu, 0x6Du, 0x70u, 0x61u, 0x6Eu, 0x69u, 0x6Fu, 0x6Eu, 0x2Du, 0x6Cu, 0x69u, + 0x6Eu, 0x6Bu, 0x04u, 0x5Fu, 0x74u, 0x63u, 0x70u, 0x05u, 0x6Cu, 0x6Fu, 0x63u, 0x61u, 0x6Cu, 0x00u, 0x00u, 0x10u, + 0x80u, 0x01u, 0x00u, 0x00u, 0x11u, 0x94u, 0x00u, 0x5Bu, 0x16u, 0x72u, 0x70u, 0x42u, 0x41u, 0x3Du, 0x30u, 0x33u, + 0x3Au, 0x43u, 0x32u, 0x3Au, 0x33u, 0x33u, 0x3Au, 0x38u, 0x36u, 0x3Au, 0x33u, 0x43u, 0x3Au, 0x45u, 0x45u, 0x11u, + 0x72u, 0x70u, 0x41u, 0x44u, 0x3Du, 0x66u, 0x33u, 0x33u, 0x37u, 0x61u, 0x38u, 0x61u, 0x32u, 0x38u, 0x64u, 0x35u, + 0x31u, 0x0Cu, 0x72u, 0x70u, 0x46u, 0x6Cu, 0x3Du, 0x30u, 0x78u, 0x32u, 0x30u, 0x30u, 0x30u, 0x30u, 0x11u, 0x72u, + 0x70u, 0x48u, 0x4Eu, 0x3Du, 0x31u, 0x66u, 0x66u, 0x64u, 0x64u, 0x64u, 0x66u, 0x33u, 0x63u, 0x39u, 0x65u, 0x33u, + 0x07u, 0x72u, 0x70u, 0x4Du, 0x61u, 0x63u, 0x3Du, 0x30u, 0x0Au, 0x72u, 0x70u, 0x56u, 0x72u, 0x3Du, 0x33u, 0x36u, + 0x30u, 0x2Eu, 0x34u, 0xC0u, 0x0Cu, 0x00u, 0x2Fu, 0x80u, 0x01u, 0x00u, 0x00u, 0x11u, 0x94u, 0x00u, 0x09u, 0xC0u, + 0x0Cu, 0x00u, 0x05u, 0x00u, 0x00u, 0x80u, 0x00u, 0x40u + ) + + val packet = DnsPacket.parse(data.asByteArray()) + assertEquals("Aidan’s MacBook Pro._companion-link._tcp.local", packet.additionals[0].name) + } + + /*@Test + fun `TestReadDomainName`() { + val data = ubyteArrayOf( + 0x00u, 0x00u, 0x00u, 0x00u, 0x00u, 0x04u, 0x00u, 0x00u, 0x00u, 0x00u, 0x00u, 0x00u, 0x0Bu, 0x5Fu, 0x67u, 0x6Fu, + 0x6Fu, 0x67u, 0x6Cu, 0x65u, 0x63u, 0x61u, 0x73u, 0x74u, 0x04u, 0x5Fu, 0x74u, 0x63u, 0x70u, 0x05u, 0x6Cu, 0x6Fu, + 0x63u, 0x61u, 0x6Cu, 0xC0u, 0x0Cu, 0x00u, 0x0Cu, 0x00u, 0x01u, 0x08u, 0x5Fu, 0x61u, 0x69u, 0x72u, 0x70u, 0x6Cu, + 0x61u, 0x79u, 0xC0u, 0x18u, 0x00u, 0x0Cu, 0x00u, 0x01u, 0x09u, 0x5Fu, 0x66u, 0x61u, 0x73u, 0x74u, 0x63u, 0x61u, + 0x73u, 0x74u, 0xC0u, 0x18u, 0x00u, 0x0Cu, 0x00u, 0x01u, 0x06u, 0x5Fu, 0x66u, 0x63u, 0x61u, 0x73u, 0x74u, 0xC0u, + 0x18u, 0x00u, 0x0Cu, 0x00u, 0x01u + ) + + val packet = DnsPacket.parse(data.asByteArray()) + println() + }*/ + + private fun loadByteArray(name: String): ByteArray { + javaClass.classLoader.getResourceAsStream(name).use { input -> + requireNotNull(input) { "File not found: $name" } + val result = ByteArrayOutputStream() + val buffer = ByteArray(4096) + var length: Int + + while ((input.read(buffer).also { length = it }) > 0) { + result.write(buffer, 0, length) + } + return result.toByteArray() + } + } + + @Test + fun `ReserializeDnsPrinter`() { + val data = loadByteArray("samsung-airplay.hex") + val packet = DnsPacket.parse(data) + val writer = DnsWriter() + writer.writePacket( + header = packet.header, + questionCount = packet.questions.size, + questionWriter = { _, _ -> }, + answerCount = packet.answers.size, + answerWriter = { w, i -> + w.write(packet.answers[i]) { v -> + val reader = packet.answers[i].getDataReader() + when (i) { + 0, 2, 3 -> v.write(reader.readPTRRecord()) + 1 -> v.write(reader.readTXTRecord()) + 4 -> v.write(reader.readSRVRecord()) + 5 -> v.write(reader.readARecord()) + } + } + }, + authorityCount = packet.authorities.size, + authorityWriter = { _, _ -> }, + additionalsCount = packet.additionals.size, + additionalWriter = { w, i -> + w.write(packet.additionals[i]) { v -> + val reader = packet.additionals[i].getDataReader() + when (i) { + 0, 1, 2 -> v.write(reader.readNSECRecord()) + 3 -> v.write(reader.readOPTRecord()) + } + } + } + ) + + assertContentEquals(data, writer.toByteArray()) + } +} \ No newline at end of file diff --git a/app/src/test/resources/samsung-airplay.hex b/app/src/test/resources/samsung-airplay.hex new file mode 100644 index 0000000000000000000000000000000000000000..8938268ca32c62ea4b892dcaafc3dafe32cfce85 GIT binary patch literal 642 zcmY+C&2G~`5XZ-DJ|PuUlnZbOw^kx%cfD(`Em>`9r{V$<RFSwv8?PN%IEm#=qHwDa zkI)C;$Yc1-f%o7ckXRcdg_U+S^Ve@ivoizm74+%o+cI1OA7h&zW0!Q()ksKLv_z2$ zfWe)ii9I~(-ioE#t`_ZV9f=H{HtvuQZ{s43)$p9U7rmKCi#!vbTeD4+mmD5f;ivaT z-WE~jqh~VSrBR$pKct=;dI9ptjmL<G?dM4>HrparKH57sd>o=zXm9cB;F{XqLM2+Z z9wa+IC&)x5ewbupQL%*5ljT~*nLqt7<p`m(rC4nf5mm_wzbW&D(nX5w-F7+GBAUDp ztT*9XSK~JQc#;?E&3YFX%Eu;tqKc@tEyppjV-jlGXL;&RLJtMrn0ea3$V2=*SE;Wv z)Mr$^2hSu#!I(`#j5N5=AO4H|$J8N*U}q66beDqslP?LQoJaE{!IniGW@DR0(jmkZ z#9~^A6XxQCFqb73k2Q`ZwdY(=Ni51_f2Mz2^;&gja-~K}wq{#djbs4zNZ_cPmeT|8 zvb=<s&9r;fYJB-oq7s48Xsz<Q(Ju$^y74>OK7yW(s+H4w=xe(Ta0gd9>>t*yQd|LC Rw7*rG0R7*;Droal`~l^-qyYc` literal 0 HcmV?d00001 -- GitLab