using CommandLine;
using Discord;
using Discord.Commands;
using Discord.Rest;
using Discord.WebSocket;
using Fleck;
using System.Collections.Concurrent;
using System.Reflection;
using System.Runtime.CompilerServices;

namespace MinecraftDiscordBot;

/// <summary>
/// Bot host: runs a Fleck WebSocket server that clients use to fetch the embedded Lua client script and log in,
/// and a Discord client that forwards prefixed commands from whitelisted text channels to the logged-in computer.
/// </summary>
public class Program : IDisposable, ICommandHandler
{
    public const string WebSocketSource = "WebSocket";
    public const string BotSource = "Bot";
    // Serializes console writes so concurrent log lines do not interleave.
    private static readonly object LogLock = new();
    /// <summary>Milliseconds a user has to react to a choice prompt before it expires.</summary>
    public const int ChoiceTimeout = 20 * 1000;
    private readonly DiscordSocketClient _client = new(new()
    {
        LogLevel = LogSeverity.Verbose,
        GatewayIntents = GatewayIntents.AllUnprivileged & ~(GatewayIntents.GuildScheduledEvents | GatewayIntents.GuildInvites)
    });
    private readonly WebSocketServer _wssv;
    private readonly BotConfiguration _config;
    private readonly HashSet _whitelistedChannels;
    private readonly ConcurrentDictionary _connections = new();
    private static readonly char[] WhiteSpace = new char[] { '\t', '\n', ' ', '\r' };
    public ITextChannel[] _channels = Array.Empty();
    private ConnectedComputer? _rsSystem = null;
    private bool disposedValue;
    /// <summary>Whether online/offline changes of <see cref="Computer"/> are broadcast; currently always disabled.</summary>
    public static bool OnlineNotifications => false;
    /// <summary>Raw Lua client script loaded from the embedded resource; contains a <c>$TOKEN</c> placeholder.</summary>
    public static readonly string ClientScript = GetClientScript();
    private readonly ITokenProvider _tokenProvider = new TimeoutTokenProvider(InstanceId, 10);
    private static readonly int InstanceId = new Random().Next();

    // Substitutes a freshly generated login token into the client script.
    private string GetVerifiedClientScript() => ClientScript.Replace("$TOKEN", _tokenProvider.GenerateToken());

    /// <summary>Reads the embedded <c>ClientScript.lua</c> resource.</summary>
    /// <exception cref="FileNotFoundException">The resource is missing from the assembly.</exception>
    private static string GetClientScript()
    {
        using var stream = Assembly.GetExecutingAssembly().GetManifestResourceStream("MinecraftDiscordBot.ClientScript.lua");
        if (stream is null) throw new FileNotFoundException("Client script could not be loaded!");
        using var sr = new StreamReader(stream);
        return sr.ReadToEnd();
    }

    /// <summary>
    /// The currently logged-in computer, or <c>null</c> if none. Changing it optionally broadcasts
    /// an online/offline notice to all whitelisted channels (see <see cref="OnlineNotifications"/>).
    /// </summary>
    public ConnectedComputer? Computer
    {
        get => _rsSystem;
        set
        {
            if (_rsSystem != value)
            {
                _rsSystem = value;
                if (OnlineNotifications)
                    _ = Task.Run(() => Broadcast(i => i.SendMessageAsync(value is null
                        ? $"The Refined Storage went offline. Please check the server!"
                        : $"The Refined Storage is back online!")));
            }
        }
    }

    // Runs the given send operation against every whitelisted channel concurrently.
    private async Task Broadcast(Func> message) => _ = await Task.WhenAll(_channels.Select(message));

    /// <summary>Wires up Discord and WebSocket event handlers; nothing is started until <see cref="RunAsync"/>.</summary>
    public Program(BotConfiguration config)
    {
        _config = config;
        _client.Log += LogAsync;
        _client.MessageReceived += (msg) => DiscordMessageReceived(msg);
        _client.ReactionAdded += DiscordReactionAdded;
        _wssv = new WebSocketServer($"ws://0.0.0.0:{config.Port}")
        {
            RestartAfterListenError = true
        };
        FleckLog.LogAction = LogWebSocket;
        _whitelistedChannels = config.Channels.ToHashSet();
    }

    // Maps Fleck log levels onto Discord.Net severities so both share one log format.
    private void LogWebSocket(LogLevel level, string message, Exception exception) => Log(new(level switch
    {
        LogLevel.Debug => LogSeverity.Debug,
        LogLevel.Info => LogSeverity.Info,
        LogLevel.Warn => LogSeverity.Warning,
        LogLevel.Error => LogSeverity.Error,
        _ => LogSeverity.Critical // Unknown logging states should behave critical
    }, WebSocketSource, message, exception));

    /// <summary>Parses the command line and runs the bot; returns 1 when argument parsing fails.</summary>
    public static Task Main(string[] args) => Parser.Default.ParseArguments(args)
        .MapResult>(
            RunWithConfig,
            RunWithConfig,
            errs => Task.FromResult(1));

    // Program owns the WebSocket server and Discord client, so dispose it once RunAsync returns.
    private static async Task RunWithConfig(IBotConfigurator arg)
    {
        using var program = new Program(arg.Config);
        return await program.RunAsync();
    }

    /// <summary>
    /// Starts the WebSocket server, logs in to Discord and then blocks forever.
    /// Returns 1 if no whitelisted channel is a usable text channel.
    /// </summary>
    public async Task RunAsync()
    {
        StartWebSocketServer();
        await _client.LoginAsync(TokenType.Bot, _config.Token);
        await _client.StartAsync();

        if (!await HasValidChannels())
            return 1;

        // Block this task until the program is closed.
        await Task.Delay(-1);

        return 0;
    }

    // Resolves the whitelisted ids into text channels; logs an error and returns false when none are usable.
    private async Task HasValidChannels()
    {
        if (await GetValidChannels(_whitelistedChannels).ToArrayAsync() is not { Length: > 0 } channels)
        {
            await LogErrorAsync(BotSource, new InvalidOperationException("No valid textchannel was whitelisted!"));
            return false;
        }
        _channels = channels;
        return true;
    }

    private void StartWebSocketServer() => _wssv.Start(socket =>
    {
        socket.OnOpen = async () => await SocketOpened(socket);
        socket.OnClose = async () => await SocketClosed(socket);
        socket.OnMessage = async message => await SocketReceived(socket, message);
    });

    /// <summary>Yields each id that resolves to a text channel, logging a warning for every id that does not.</summary>
    private async IAsyncEnumerable GetValidChannels(IEnumerable ids)
    {
        foreach (var channelId in ids)
        {
            var channel = await _client.GetChannelAsync(channelId);
            if (channel is not ITextChannel textChannel)
            {
                if (channel is null)
                    await LogWarningAsync(BotSource, $"Channel with id [{channelId}] does not exist!");
                else
                    await LogWarningAsync(BotSource, $"Channel is not a text channels and will not be used: {channel.Name} [{channel.Id}]!");
                continue;
            }

            if (textChannel.Guild is RestGuild guild)
            {
                // Refresh the REST guild so its name is available for the log line.
                await guild.UpdateAsync();
                await LogInfoAsync(BotSource, $"Whitelisted in channel: {channel.Name} [{channel.Id}] on server {guild.Name} [{guild.Id}]");
            }
            else
            {
                await LogWarningAsync(BotSource, $"Whitelisted in channel: {channel.Name} [{channel.Id}] on unknown server!");
            }
            yield return textChannel;
        }
    }

    /// <summary>
    /// WebSocket protocol: <c>getcode</c> returns the client script, <c>login=&lt;token&gt;</c> logs a computer in,
    /// anything else terminates the connection.
    /// </summary>
    private async Task SocketReceived(IWebSocketConnection socket, string message)
    {
        await LogInfoAsync(WebSocketSource, $"[{socket.ConnectionInfo.Id}] Received: {message}");
        await (message switch
        {
            "getcode" => SendClientCode(socket),
            // Protocol keyword, so compare ordinally rather than with the current culture.
            string s when s.StartsWith("login=", StringComparison.Ordinal) => ClientComputerConnected(socket, s[6..]),
            _ => DisruptClientConnection(socket, "Protocol violation!")
        });
    }

    // Accepts the computer only if its token verifies; otherwise tells it the script is "outdated" and closes.
    private async Task ClientComputerConnected(IWebSocketConnection socket, string token)
    {
        if (!_tokenProvider.VerifyToken(token))
        {
            await DisruptClientConnection(socket, "outdated");
            return;
        }
        await LogInfoAsync(WebSocketSource, $"[{socket.ConnectionInfo.Id}] Client logged in with valid script!");
        AddComputerSocket(socket, new(socket));
    }

    // Sends the reason to the client, logs it, then closes the socket.
    private static async Task DisruptClientConnection(IWebSocketConnection socket, string reason)
    {
        await socket.Send(reason);
        await LogWarningAsync(WebSocketSource, $"[{socket.ConnectionInfo.Id}] Client will be terminated, reason: {reason}");
        socket.Close();
    }

    private async Task SendClientCode(IWebSocketConnection socket)
    {
        await socket.Send(GetVerifiedClientScript());
        await LogInfoAsync(WebSocketSource, $"[{socket.ConnectionInfo.Id}] Script sent to client!");
    }

    // Only one computer is tracked; a new login replaces the previous one.
    private void AddComputerSocket(IWebSocketConnection socket, ConnectedComputer pc) => Computer = pc;

    // Clears Computer only if the closing socket is the one currently registered.
    private void RemoveComputerSocket(IWebSocketConnection socket)
    {
        if (Computer is { ConnectionInfo.Id: Guid id } && id == socket.ConnectionInfo.Id)
            Computer = null;
    }

    private async Task SocketClosed(IWebSocketConnection socket)
    {
        RemoveComputerSocket(socket);
        await LogInfoAsync(WebSocketSource, $"[{socket.ConnectionInfo.Id}] Client disconnected!");
    }

    private static async Task SocketOpened(IWebSocketConnection socket)
        => await LogInfoAsync(WebSocketSource, $"[{socket.ConnectionInfo.Id}] Client connected from {socket.ConnectionInfo.ClientIpAddress}:{socket.ConnectionInfo.ClientPort}!");

    /// <summary>
    /// Handles a Discord message from a whitelisted channel: prefixed commands are dispatched to
    /// <see cref="HandleCommand"/> in the background, other messages are only logged.
    /// </summary>
    /// <param name="arg">The incoming message.</param>
    /// <param name="timeout">Milliseconds after which the command's cancellation token is signalled.</param>
    private async Task DiscordMessageReceived(SocketMessage arg, int timeout = 10000)
    {
        if (arg is not SocketUserMessage message) return;
        if (message.Author.IsBot) return;
        if (!IsChannelWhitelisted(arg.Channel)) return;

        if (IsCommand(message, out var argPos))
        {
            var parameters = message.Content[argPos..].Split(WhiteSpace, StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
            // Start the timeout now, but dispose the source (it owns a timer) once the command is answered.
            var cts = new CancellationTokenSource(timeout);
            _ = Task.Run(async () =>
            {
                using (cts)
                {
                    try
                    {
                        var response = await HandleCommand(message, parameters, cts.Token);
                        await SendResponse(message, response);
                    }
                    catch (Exception e)
                    {
                        // Fire-and-forget task: log failures (e.g. replying failed) instead of losing them.
                        await LogErrorAsync(BotSource, e);
                    }
                }
            });
            return;
        }

        await LogInfoAsync("Discord", $"[{arg.Author.Username}] {arg.Content}");
        // TODO: Relay Message to Chat Receiver
    }

    // Replies according to the concrete response kind.
    private Task SendResponse(SocketUserMessage message, ResponseType response) => response switch
    {
        ResponseType.IChoiceResponse res => HandleChoice(message, res),
        ResponseType.StringResponse res => message.ReplyAsync(res.Message),
        _ => message.ReplyAsync($"Whoops, someone forgot to implement '{response.GetType()}' responses?"),
    };

    // Pending choice prompts, keyed by the id of the bot's reply message.
    private readonly ConcurrentDictionary _choiceWait = new();

    /// <summary>
    /// Resolves a pending choice when a user reacts to the bot's prompt: the choice is removed and the prompt deleted.
    /// </summary>
    private async Task DiscordReactionAdded(Cacheable message, Cacheable channel, SocketReaction reaction)
    {
        // Ignore the reactions the bot itself adds when presenting a choice.
        if (reaction.UserId == _client.CurrentUser.Id) return;
        if (!_choiceWait.TryRemove(message.Id, out _))
        {
            await LogInfoAsync(BotSource, "Reaction was added to message without choice object!");
            return;
        }
        // NOTE(review): the removed choice's HandleResult is never invoked, so the selection is currently discarded.
        var msgObject = await message.GetOrDownloadAsync();
        // The download yields null if the message no longer exists.
        if (msgObject is not null)
            await msgObject.DeleteAsync();
        await LogInfoAsync(BotSource, $"Reaction {reaction.Emote.Name} was added to the choice by {reaction.UserId}!");
    }

    /// <summary>
    /// Posts the choice prompt, registers it for reactions and expires it after <see cref="ChoiceTimeout"/> ms.
    /// </summary>
    private async Task HandleChoice(SocketUserMessage message, ResponseType.IChoiceResponse res)
    {
        var reply = await message.ReplyAsync($"{res.Query}\n{string.Join("\n", res.Options)}");
        _choiceWait[reply.Id] = res;
        var reactions = new Emoji[] { new("0️⃣")/*, new("1️⃣"), new("2️⃣"), new("3️⃣"), new("4️⃣"), new("5️⃣"), new("6️⃣"), new("7️⃣"), new("8️⃣"), new("9️⃣")*/ };
        await reply.AddReactionsAsync(reactions);
        _ = Task.Run(async () =>
        {
            await Task.Delay(ChoiceTimeout);
            // The choice is keyed by the reply. If it is already gone, the user chose and the reply was deleted,
            // so there is nothing left to expire (editing the deleted reply would fail).
            if (!_choiceWait.TryRemove(reply.Id, out _)) return;
            await reply.ModifyAsync(i => i.Content = "You did not choose in time!");
            await reply.RemoveAllReactionsAsync();
        });
    }

    /// <summary>
    /// Forwards a command to the logged-in computer and turns failures into user-facing replies.
    /// </summary>
    /// <returns>The handler's response, or a string response describing why the command could not be handled.</returns>
    public async Task HandleCommand(SocketUserMessage message, string[] parameters, CancellationToken ct)
    {
        if (Computer is not ICommandHandler handler)
            return ResponseType.AsString("The Minecraft server is currently unavailable!");

        try
        {
            return await handler.HandleCommand(message, parameters, ct);
        }
        catch (TaskCanceledException)
        {
            return ResponseType.AsString("Your request could not be processed in time!");
        }
        catch (ReplyException e)
        {
            await LogWarningAsync(BotSource, e.Message);
            return ResponseType.AsString($"Your request failed: {e.Message}");
        }
        catch (Exception e)
        {
            await LogErrorAsync(BotSource, e);
            return ResponseType.AsString($"Oopsie doopsie, this should not have happened!");
        }
    }

    // Detects the configured command prefix; argPos receives the index just past it.
    private bool IsCommand(SocketUserMessage message, out int argPos)
    {
        argPos = 0;
        return message.HasStringPrefix(_config.Prefix, ref argPos);
    }

    private bool IsChannelWhitelisted(ISocketMessageChannel channel) => _whitelistedChannels.Contains(channel.Id);

    public static ConfiguredTaskAwaitable LogInfoAsync(string source, string message) => LogAsync(new(LogSeverity.Info, source, message)).ConfigureAwait(false);
    public static ConfiguredTaskAwaitable LogWarningAsync(string source, string message) => LogAsync(new(LogSeverity.Warning, source, message)).ConfigureAwait(false);
    public static ConfiguredTaskAwaitable LogErrorAsync(string source, Exception exception) => LogAsync(new(LogSeverity.Error, source, exception?.Message, exception)).ConfigureAwait(false);
    public static void LogInfo(string source, string message) => Log(new(LogSeverity.Info, source, message));
    public static void LogWarning(string source, string message) => Log(new(LogSeverity.Warning, source, message));
    public static void LogError(string source, Exception exception) => Log(new(LogSeverity.Error, source, exception?.Message, exception));

    // Task-returning adapter so the synchronous logger can be attached to DiscordSocketClient.Log.
    private static async Task LogAsync(LogMessage msg)
    {
        Log(msg);
        await Task.CompletedTask;
    }

    /// <summary>Writes one log line to the console under <see cref="LogLock"/>.</summary>
    public static void Log(LogMessage msg)
    {
        lock (LogLock)
            Console.WriteLine(msg.ToString());
    }

    /// <summary>Disposes the WebSocket server and the Discord client (managed resources only; no finalizer is needed).</summary>
    protected virtual void Dispose(bool disposing)
    {
        if (!disposedValue)
        {
            if (disposing)
            {
                _wssv.Dispose();
                _client.Dispose();
            }
            disposedValue = true;
        }
    }

    public void Dispose()
    {
        // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
        Dispose(disposing: true);
        GC.SuppressFinalize(this);
    }
}

/// <summary>Base type of all replies a command can produce.</summary>
public abstract class ResponseType
{
    // Fallback option renderer for FromChoice; rejects values whose ToString yields null.
    private static string DefaultDisplay(T obj) => obj?.ToString() ?? throw new InvalidProgramException("ToString did not yield anything!");

    /// <summary>Creates a plain text reply.</summary>
    public static ResponseType AsString(string message) => new StringResponse(message);

    /// <summary>Creates a reply that asks the user to pick one of <paramref name="choice"/>.</summary>
    public static ResponseType FromChoice(string query, IEnumerable choice, Func resultHandler, Func? display = null)
        => new ChoiceResponse(query, choice, resultHandler, display ?? DefaultDisplay);

    public class StringResponse : ResponseType
    {
        public StringResponse(string message) => Message = message;
        public string Message { get; }
    }

    /// <summary>Non-generic view of a choice so the bot can display it and dispatch the selected index.</summary>
    public interface IChoiceResponse
    {
        IEnumerable Options { get; }
        string Query { get; }
        Task HandleResult(int index);
    }

    public class ChoiceResponse : ResponseType, IChoiceResponse
    {
        private readonly Func _resultHandler;
        private readonly T[] _options;
        private readonly Func _displayer;

        public IEnumerable Options => _options.Select(_displayer);
        public string Query { get; }

        // Invokes the result handler with the option at the given zero-based index.
        public Task HandleResult(int index) => _resultHandler(_options[index]);

        public ChoiceResponse(string query, IEnumerable choice, Func resultHandler, Func display)
        {
            Query = query;
            _resultHandler = resultHandler;
            // Snapshot the options so the displayed order matches the indices passed to HandleResult.
            _options = choice.ToArray();
            _displayer = display;
        }
    }
}