From 13f1a784a1e99aedd8e21984c9145af3ed5b2e0a Mon Sep 17 00:00:00 2001 From: Glenn Date: Tue, 11 Apr 2023 10:42:17 +0200 Subject: [PATCH] Initial import --- .gitignore | 3 + Makefile | 8 +++ README.md | 49 ++++++++++++++ shard.lock | 6 ++ shard.yml | 26 ++++++++ specs/test_data/adam_eve.txt | 35 ++++++++++ src/builders/generic.cr | 3 + src/builders/openai_chat.cr | 53 +++++++++++++++ src/builders/string.cr | 41 ++++++++++++ src/main.cr | 122 +++++++++++++++++++++++++++++++++++ src/parsers/generic.cr | 3 + src/parsers/request.cr | 7 ++ src/parsers/string.cr | 50 ++++++++++++++ src/prompt.cr | 20 ++++++ src/zone.cr | 26 ++++++++ 15 files changed, 452 insertions(+) create mode 100644 .gitignore create mode 100644 Makefile create mode 100644 README.md create mode 100644 shard.lock create mode 100644 shard.yml create mode 100644 specs/test_data/adam_eve.txt create mode 100644 src/builders/generic.cr create mode 100644 src/builders/openai_chat.cr create mode 100644 src/builders/string.cr create mode 100644 src/main.cr create mode 100644 src/parsers/generic.cr create mode 100644 src/parsers/request.cr create mode 100644 src/parsers/string.cr create mode 100644 src/prompt.cr create mode 100644 src/zone.cr diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..fa50d30 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ + +/bin +/lib diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..70cf497 --- /dev/null +++ b/Makefile @@ -0,0 +1,8 @@ + +all: build + +build: + shards build + + + diff --git a/README.md b/README.md new file mode 100644 index 0000000..4bc1ed7 --- /dev/null +++ b/README.md @@ -0,0 +1,49 @@ + +Cf interface de NovelAI + +## Fonctionnement + +### Auteur propose une description de chaque personnage, lieu, objet (world info) + + ## Character1 + + ## Character2 + + ## Character3 + +### AI extrait les infos sur les personnages + +Résume tout ce qu'on sait sur le personnage Tintin en JSON, en respectant le format suivant, en faisant des 
phrases très courtes: + +result=[ +{name: "", body_description: [...], psychology: [...], values: [...], intentions: [...], challenges: [...], relationships: [{name: "...", description: "..."}]} +] + +### Auteur insère des marques de "chapitres" (ex: @@) + +Fais les requêtes suivantes, chapitre par chapitre + +### AI propose un résumé succinct (pour chaque chapitre) + +PROMPT: + + Résume succinctement CHAPITRE, en étant précis sur les événements mentionnés. + +### AI propose les enjeux (pour chaque chapitre) + +PROMPT: + + ## Request + + Liste les émotions et intentions de chaque personnage au format JSON, en respectant la structure suivante. + + result={ + "character1": { fears: [...], intentions: [...], emotions: [...], hidden_desires: [...] } + "character2": { fears: [...], intentions: [...], emotions: [...], hidden_desires: [...] } + } + + ## Answer + + result={ + + diff --git a/shard.lock b/shard.lock new file mode 100644 index 0000000..980beb4 --- /dev/null +++ b/shard.lock @@ -0,0 +1,6 @@ +version: 2.0 +shards: + openai: + git: https://github.com/lancecarlson/openai.cr.git + version: 0.1.0+git.commit.852bcd9b37d8472a4c72a2498ebb90351048fa68 + diff --git a/shard.yml b/shard.yml new file mode 100644 index 0000000..ed456d4 --- /dev/null +++ b/shard.yml @@ -0,0 +1,26 @@ +name: ai-storyteller +version: 0.1.0 + +targets: + storyteller: + main: src/main.cr + +# authors: +# - name + +# description: | +# Short description of ai-storyteller + +dependencies: + openai: + github: lancecarlson/openai.cr + +# pg: +# github: will/crystal-pg +# version: "~> 0.5" + +# development_dependencies: +# webmock: +# github: manastech/webmock.cr + +# license: MIT diff --git a/specs/test_data/adam_eve.txt b/specs/test_data/adam_eve.txt new file mode 100644 index 0000000..3d9459d --- /dev/null +++ b/specs/test_data/adam_eve.txt @@ -0,0 +1,35 @@ +story--some-nice-story + +@@system +## Characters + +Eve est une femme + +Adam est un homme + +## Context + +Adam et Eve vivent au Paradis. 
+Ils profitent d'une vie heureuse jusque là.
+
+## Synopsis
+
+Eve rencontre un serpent et mange le fruit défendu.
+((imagine la suite, la réaction d'Adam, etc.))
+
+## Contraintes
+
+Ecris en français, au présent de l'indicatif.
+
+@@before
+## Récit
+
+Eve se balade dans les jardins d'Eden. Le soleil est haut dans le ciel.
+
+@@before
+Eve a chaud. Elle se refraichirait bien dans la riviere.
+
+Elle descend dans
+@@after
+
+This text will not be used in the chat
diff --git a/src/builders/generic.cr b/src/builders/generic.cr
new file mode 100644
index 0000000..a6b300c
--- /dev/null
+++ b/src/builders/generic.cr
@@ -0,0 +1,5 @@
+# Common interface for objects that render a Prompt into an output format.
+abstract class PromptGenericBuilder
+  # Renders *prompt*; the return type is left to each implementation.
+  abstract def build(prompt : Prompt)
+end
diff --git a/src/builders/openai_chat.cr b/src/builders/openai_chat.cr
new file mode 100644
index 0000000..5eb7491
--- /dev/null
+++ b/src/builders/openai_chat.cr
@@ -0,0 +1,68 @@
+require "pretty_print"
+require "colorize"
+
+require "./generic"
+
+# Renders a Prompt as the message list sent to the OpenAI chat endpoint.
+#
+# The prelude and future zones are never sent. System and present entries
+# are always included; past entries are added, most recent first, for as
+# long as the estimated token budget allows.
+class OpenAIChatBuilder < PromptGenericBuilder
+  alias OpenAIMessage = NamedTuple(role: String, content: String)
+  alias OpenAIChat = Array(OpenAIMessage)
+
+  # Estimated token budget for the messages of one request.
+  TOKEN_LIMIT = 2_900
+
+  # Rough heuristic used for estimates: one token per 4 characters.
+  CHARS_PER_TOKEN = 4
+
+  getter verbose : Bool
+
+  def initialize(@verbose)
+  end
+
+  # Builds the chat for *prompt*: system entries, then the most recent past
+  # entries that fit within TOKEN_LIMIT (in chronological order), then the
+  # present entries.
+  def build(prompt : Prompt) : OpenAIChat
+    chat = [] of OpenAIMessage
+
+    mandatory_token_count = prompt.system_zone.token_count + prompt.present_zone.token_count
+
+    # Mandatory system messages
+    prompt.system_zone.each do |content|
+      chat << {role: "system", content: content}
+    end
+
+    # Optional past messages, walked from newest to oldest so that the most
+    # recent context is kept when the budget runs out
+    history = [] of OpenAIMessage
+    history_token_count = 0
+    prompt.past_zone.reverse_each do |content|
+      content_token_count = content.size // CHARS_PER_TOKEN
+      estimated_token_count = content_token_count + history_token_count + mandatory_token_count
+
+      if @verbose
+        # Diagnostics go to STDERR so they never mix with a story written
+        # to STDOUT
+        STDERR.puts history.reverse.pretty_inspect
+        STDERR.puts "ESTIMATE: #{estimated_token_count} (limit=#{TOKEN_LIMIT})".colorize(:yellow)
+      end
+
+      break if estimated_token_count >= TOKEN_LIMIT
+
+      history << {role: "user", content: content}
+      history_token_count += content_token_count
+    end
+    chat.concat(history.reverse)
+
+    # Mandatory present messages
+    prompt.present_zone.each do |content|
+      chat << {role: "user", content: content}
+    end
+
+    chat
+  end
+end
diff --git a/src/builders/string.cr b/src/builders/string.cr
new file mode 100644
index 0000000..febcc99
--- /dev/null
+++ b/src/builders/string.cr
@@ -0,0 +1,37 @@
+require "colorize"
+
+require "./generic"
+
+# Renders a Prompt back into the "@@tag" text format read by StringParser.
+class StringBuilder < PromptGenericBuilder
+  getter use_color : Bool
+
+  # NOTE: sets the module-level `Colorize.enabled` flag.
+  def initialize(@use_color)
+    Colorize.enabled = @use_color
+  end
+
+  # Returns the prelude followed by every zone entry, each entry preceded by
+  # its zone marker.
+  def build(prompt : Prompt)
+    String.build do |io|
+      prompt.prelude_zone.each { |content| io << content }
+
+      prompt.system_zone.each do |content|
+        io << "@@system".colorize(:yellow) << content
+      end
+
+      prompt.past_zone.each do |content|
+        io << "@@before".colorize(:yellow) << content
+      end
+
+      prompt.present_zone.each do |content|
+        io << "@@current".colorize(:yellow) << content.colorize(:light_cyan)
+      end
+
+      prompt.future_zone.each do |content|
+        io << "@@after".colorize(:yellow) << content
+      end
+    end
+  end
+end
diff --git a/src/main.cr b/src/main.cr
new file mode 100644
index 0000000..72d47a5
--- /dev/null
+++ b/src/main.cr
@@ -0,0 +1,121 @@
+require "option_parser"
+require "pretty_print"
+require "openai"
+
+require "./zone"
+require "./parsers/string"
+require "./builders/string"
+require "./builders/openai_chat"
+
+# Command-line entry point: reads a tagged prompt, asks the OpenAI chat API
+# to continue it and writes the extended prompt back out.
+class Storyteller
+  def initialize
+  end
+
+  # Parses *argv* and runs the read / complete / write pipeline.
+  def self.start(argv)
+    input_file_path = ""
+    output_file_path = ""
+    use_color = true
+    make_request = true
+    verbose = false
+
+    OptionParser.parse(argv) do |parser|
+      parser.banner = "Usage: storyteller [options]"
+
+      parser.on("-i FILE", "--input=FILE", "Path to input file") do |file|
+        input_file_path = file
+      end
+
+      parser.on("-v", "--verbose", "Be verbose") do
+        verbose = true
+      end
+
+      parser.on("--dry-run", "Don't call the API") do
+        make_request = false
+      end
+
+      parser.on("-n", "--no-color", "Disable color output") do
+        use_color = false
+      end
+
+      parser.on("-o FILE", "--output=FILE", "Path to output file") do |file|
+        # Output written to a file is never colorized
+        use_color = false
+        output_file_path = file
+      end
+
+      parser.on("-h", "--help", "Show this help") do
+        puts parser
+        exit
+      end
+    end
+
+    storyteller = Storyteller.new
+
+    # Read the prompt. The block form closes the file even when parsing
+    # raises, and STDIN is left open when it is the input.
+    prompt = if input_file_path.empty?
+               storyteller.read_file(STDIN)
+             else
+               File.open(input_file_path) { |file| storyteller.read_file(file) }
+             end
+
+    # Build the chat request and, unless this is a dry run, append the
+    # completion to the prompt
+    prompt = storyteller.complete(prompt, make_request, verbose)
+    exit 0 unless make_request
+
+    # Write the extended prompt
+    if output_file_path.empty?
+      storyteller.write_file(STDOUT, prompt, use_color)
+      STDOUT.flush
+    else
+      File.open(output_file_path, "w") do |file|
+        storyteller.write_file(file, prompt, use_color)
+      end
+    end
+  end
+
+  # Builds the chat messages for *prompt* and, when *make_request* is true,
+  # sends them to the API and appends the answer to the present zone.
+  # Returns *prompt*.
+  def complete(prompt : Prompt, make_request : Bool, verbose : Bool)
+    builder = OpenAIChatBuilder.new(verbose: verbose)
+    messages = builder.build(prompt)
+
+    return prompt unless make_request
+
+    # Raises KeyError when OPENAI_API_KEY is not set
+    openai = OpenAI::Client.new(access_token: ENV.fetch("OPENAI_API_KEY"))
+    result = openai.chat(
+      "gpt-3.5-turbo",
+      messages,
+      {
+        "temperature" => 0.82,
+        "presence_penalty" => 1,
+        "frequency_penalty" => 1,
+        "max_tokens" => 256
+      }
+    )
+    prompt.present_zone.content << "\n" + result.choices.first["message"]["content"] + "\n"
+    prompt
+  end
+
+  # Parses the whole content of *input_file* into a Prompt.
+  def read_file(input_file : IO::FileDescriptor)
+    StringParser.new.parse(input_file.gets_to_end)
+  end
+
+  # Renders *prompt* as tagged text and writes it to *output_file*.
+  def write_file(output_file : IO::FileDescriptor, prompt : Prompt, use_color : Bool)
+    output_file.print(StringBuilder.new(use_color).build(prompt))
+  end
+
+  def display_completion(completion : String)
+    # TODO: display the completion
+  end
+end
+
+Storyteller.start(ARGV)
diff --git a/src/parsers/generic.cr b/src/parsers/generic.cr
new file mode 100644
index 0000000..bde35a3
--- /dev/null
+++ b/src/parsers/generic.cr
@@ -0,0 +1,8 @@
+# Intended common interface for prompt parsers.
+#
+# NOTE(review): StringParser does not inherit from this class, and its
+# #parse takes the text to parse as an argument; reconcile the two
+# signatures before adding subclasses.
+abstract class PromptGenericParser
+  abstract def parse : Prompt
+end
diff --git a/src/parsers/request.cr b/src/parsers/request.cr
new file mode 100644
index 0000000..43555d8
--- /dev/null
+++ b/src/parsers/request.cr
@@ -0,0 +1,11 @@
+require "../prompt"
+
+# Wraps a Prompt.
+# NOTE(review): no parsing logic yet and not required by src/main.cr;
+# presumably a placeholder, confirm before relying on it.
+class PromptRequestParser
+  getter prompt : Prompt
+
+  def initialize(@prompt : Prompt)
+  end
+end
diff --git a/src/parsers/string.cr b/src/parsers/string.cr
new file mode 100644
index 0000000..8fa210c
--- /dev/null
+++ b/src/parsers/string.cr
@@ -0,0 +1,48 @@
+require "../prompt"
+
+# Parses the "@@tag" text format into a Prompt.
+#
+# Text before the first recognized tag goes to the prelude zone. Each
+# recognized tag ("@@system", "@@before", "@@current", "@@after") starts a
+# new entry in the matching zone, running up to the next recognized tag. An
+# unrecognized "@@" is kept as plain text.
+class StringParser
+  def parse(current : String) : Prompt
+    prompt = Prompt.new
+    prompt.prelude_zone.content << read_string(prompt.zones, current)
+
+    # read_string fills each zone from the last entry to the first, so put
+    # the entries back in document order
+    prompt.zones.each { |zone| zone.content.reverse! }
+
+    prompt
+  end
+
+  # Consumes *current* recursively. Returns the text located before the
+  # first recognized tag; the text following each recognized tag is appended
+  # to the matching zone of *zone_list*.
+  def read_string(zone_list : Array(Zone), current : String) : String
+    pos = current.index("@@")
+
+    # No "@@" left: everything is plain text
+    return current if pos.nil?
+
+    # Text before the "@@" is plain text; parse from the "@@" onwards
+    if pos > 0
+      return current[0...pos] + read_string(zone_list, current[pos..])
+    end
+
+    # "@@" is at position 0: try to recognize a tag.
+    # NOTE(review): this is a prefix match, so e.g. "@@systematic" is read
+    # as the "system" tag followed by "atic".
+    zone = zone_list.find { |candidate| current.starts_with?("@@" + candidate.tag) }
+
+    # Not a known tag: keep the "@@" as text and parse what follows
+    return "@@" + read_string(zone_list, current[2..]) if zone.nil?
+
+    # Known tag: skip it; the text up to the next recognized tag becomes a
+    # new entry of the zone
+    zone.content << read_string(zone_list, current[(2 + zone.tag.size)..])
+    ""
+  end
+end
diff --git a/src/prompt.cr b/src/prompt.cr
new file mode 100644
index 0000000..09931aa
--- /dev/null
+++ b/src/prompt.cr
@@ -0,0 +1,16 @@
+require "./zone"
+
+# A story prompt split into zones.
+class Prompt
+  # Untagged text preceding the first tag
+  getter prelude_zone = Zone.new
+  getter system_zone = Zone.new(tag: "system")
+  getter past_zone = Zone.new(tag: "before")
+  getter present_zone = Zone.new(tag: "current")
+  getter future_zone = Zone.new(tag: "after")
+
+  # The tagged zones (the prelude, which has no tag, is not included).
+  def zones : Array(Zone)
+    [system_zone, past_zone, present_zone, future_zone]
+  end
+end
diff --git a/src/zone.cr b/src/zone.cr
new file mode 100644
index 0000000..3ba0c06
--- /dev/null
+++ b/src/zone.cr
@@ -0,0 +1,24 @@
+# An ordered list of text entries sharing the same tag.
+class Zone
+  property tag : String
+  property content : Array(String)
+
+  def initialize(@tag = "", @content = [] of String)
+  end
+
+  # Rough size estimate of the zone: one token per 4 characters of each
+  # entry. Integer division keeps the result an Int32.
+  def token_count : Int32
+    content.sum { |item| item.size // 4 }
+  end
+
+  # Yields each entry, first to last.
+  def each(&block)
+    content.each { |item| yield item }
+  end
+
+  # Yields each entry, last to first.
+  def reverse_each(&block)
+    content.reverse_each { |item| yield item }
+  end
+end