Last active
May 7, 2024 18:25
-
-
Save rodydavis/b86ed505b83e1871bb9e1cbab4b44bfc to your computer and use it in GitHub Desktop.
Flutter AI Theme Generation (Function Calling)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Copyright 2024 Google LLC | |
// | |
// Licensed under the Apache License, Version 2.0 (the "License"); | |
// you may not use this file except in compliance with the License. | |
// You may obtain a copy of the License at | |
// | |
// http://www.apache.org/licenses/LICENSE-2.0 | |
// | |
// Unless required by applicable law or agreed to in writing, software | |
// distributed under the License is distributed on an "AS IS" BASIS, | |
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
// See the License for the specific language governing permissions and | |
// limitations under the License. | |
import 'dart:async'; | |
import 'package:flutter/material.dart'; | |
import 'package:flutter_markdown/flutter_markdown.dart'; | |
import 'package:google_generative_ai/google_generative_ai.dart'; | |
import 'package:url_launcher/link.dart'; | |
/// Seed color from which the light and dark color schemes are generated.
///
/// Written by the `change_theme_color` tool handler.
final ValueNotifier<Color> themeColor = ValueNotifier(Colors.orangeAccent);

/// Brightness mode of the app; written by the `change_theme_mode` handler.
final ValueNotifier<ThemeMode> themeMode = ValueNotifier(ThemeMode.light);

/// App-wide linear text scale; written by the `change_text_scale_factor`
/// handler.
final ValueNotifier<double> textScaleFactor = ValueNotifier(1.0);
/// Launches the theme-editor sample app.
void main() => runApp(const GenerativeAISample());
/// Root widget of the sample.
///
/// Rebuilds the [MaterialApp] whenever [themeColor], [themeMode] or
/// [textScaleFactor] notifies, so theme changes requested by the model are
/// applied to the whole app.
class GenerativeAISample extends StatelessWidget {
  const GenerativeAISample({super.key});

  /// Returns a [ThemeData] of the given [brightness] whose color scheme is
  /// generated from the current [themeColor].
  ThemeData theme(Brightness brightness) {
    final colors = ColorScheme.fromSeed(
      brightness: brightness,
      seedColor: themeColor.value,
    );
    return ThemeData(
      brightness: brightness,
      colorScheme: colors,
      scaffoldBackgroundColor: colors.surface,
    );
  }

  @override
  Widget build(BuildContext context) {
    return AnimatedBuilder(
      animation: Listenable.merge([themeColor, themeMode, textScaleFactor]),
      builder: (context, _) {
        return MaterialApp(
          debugShowCheckedModeBanner: false,
          title: 'Flutter + GenAI',
          theme: theme(Brightness.light),
          darkTheme: theme(Brightness.dark),
          themeMode: themeMode.value,
          builder: (context, child) {
            // `MediaQueryData.textScaleFactor` is deprecated; a linear
            // `TextScaler` is its direct replacement. `child` is non-null
            // here because `home` is provided.
            return MediaQuery(
              data: MediaQuery.of(context).copyWith(
                textScaler: TextScaler.linear(textScaleFactor.value),
              ),
              child: child!,
            );
          },
          home: const ChatScreen(title: 'Theme Editor'),
        );
      },
    );
  }
}
/// Home screen: asks for a Gemini API key first, then shows the chat.
class ChatScreen extends StatefulWidget {
  const ChatScreen({super.key, required this.title});

  /// Title shown in the app bar of whichever page is displayed.
  final String title;

  @override
  State<ChatScreen> createState() => _ChatScreenState();
}
/// Shows [ApiKeyWidget] until a key has been submitted, then [Example].
class _ChatScreenState extends State<ChatScreen> {
  /// The submitted API key, or `null` while none has been entered.
  String? apiKey;

  @override
  Widget build(BuildContext context) {
    final storedKey = apiKey;
    if (storedKey != null) {
      return Example(title: widget.title, apiKey: storedKey);
    }
    return ApiKeyWidget(
      title: widget.title,
      onSubmitted: (submitted) => setState(() => apiKey = submitted),
    );
  }
}
/// Chat page where the user asks Gemini to change the app theme; the model
/// does so through function calls handled in [_ExampleState].
class Example extends StatefulWidget {
  const Example({
    super.key,
    required this.apiKey,
    required this.title,
  });

  /// Gemini API key used to construct the model.
  final String apiKey;

  /// Title shown in the app bar.
  final String title;

  @override
  State<Example> createState() => _ExampleState();
}
/// State for [Example].
///
/// Owns the Gemini [model] (declared with three theme-editing tools), the
/// conversation history sent with every request, and the on-screen
/// transcript.
class _ExampleState extends State<Example> {
  /// Whether a model request is currently in flight.
  final loading = ValueNotifier(false);

  /// Messages shown in the chat, oldest first.
  final messages = ValueNotifier<List<(Sender, String)>>([]);

  /// Controller for the message input field.
  final controller = TextEditingController();

  /// Every [Content] exchanged with the model so far; prepended to each new
  /// request so the model sees the whole conversation.
  final _history = <Content>[];

  /// Maximum number of function-call round trips handled for one user
  /// message, so a model that keeps requesting tools cannot recurse forever.
  static const _maxToolRounds = 5;

  /// Gemini model with the `change_theme_color`, `change_theme_mode` and
  /// `change_text_scale_factor` tools declared.
  late final model = GenerativeModel(
    model: 'gemini-pro',
    apiKey: widget.apiKey,
    requestOptions: const RequestOptions(apiVersion: 'v1beta'),
    tools: [
      Tool(
        functionDeclarations: <FunctionDeclaration>[
          FunctionDeclaration(
            'change_theme_color',
            'Change the current theme color',
            Schema(
              SchemaType.object,
              properties: {
                'hex': Schema(
                  SchemaType.string,
                  description: 'Must be 6 in length. FF00EE,000000,FFFFFF',
                ),
              },
            ),
          ),
          FunctionDeclaration(
            'change_theme_mode',
            'Change the current theme mode',
            Schema(
              SchemaType.object,
              properties: {
                'mode': Schema(
                  SchemaType.string,
                  description:
                      'Must be one of the following: light,dark,system',
                ),
              },
            ),
          ),
          FunctionDeclaration(
            'change_text_scale_factor',
            'Change the current font scale, where 1 represents 14px and 2.0 = 48px',
            Schema(
              SchemaType.object,
              properties: {
                'scale': Schema(
                  SchemaType.number,
                  description: 'Valid font scale, defaults to 1.0',
                ),
              },
            ),
          ),
        ],
      ),
    ],
  );

  @override
  void dispose() {
    loading.dispose();
    messages.dispose();
    controller.dispose();
    super.dispose();
  }

  /// Sends the input field's text to the model and appends the reply (or an
  /// error description) to [messages].
  ///
  /// Does nothing when the text is blank or a request is already running, so
  /// overlapping requests cannot interleave writes to [_history].
  Future<void> sendMessage() async {
    if (loading.value) return;
    final message = controller.text.trim();
    if (message.isEmpty) return;
    controller.clear();
    addMessage(Sender.user, message);
    loading.value = true;
    try {
      final response = await callWithActions([Content.text(message)]);
      // The page may have been closed while awaiting; its notifiers are
      // disposed by then.
      if (!mounted) return;
      addMessage(
        Sender.system,
        response.text ?? 'Something went wrong, please try again.',
      );
    } catch (e) {
      if (mounted) addMessage(Sender.system, 'Error sending message: $e');
    } finally {
      if (mounted) loading.value = false;
    }
  }

  /// Sends [prompt] (after [_history]) to the model, executes any function
  /// calls it requests, and returns the model's final response.
  ///
  /// At most [_maxToolRounds] function-call round trips are performed; if the
  /// limit is hit, the last response (still containing function calls) is
  /// returned as-is.
  Future<GenerateContentResponse> callWithActions(
    Iterable<Content> prompt,
  ) =>
      _generate(prompt.toList(), 0);

  /// One request/response round; [round] counts function-call follow-ups.
  Future<GenerateContentResponse> _generate(
    List<Content> turn,
    int round,
  ) async {
    final response = await model.generateContent([..._history, ...turn]);
    if (response.candidates.isNotEmpty) {
      _history
        ..addAll(turn)
        ..add(response.candidates.first.content);
    }
    final calls = response.functionCalls.toList();
    if (calls.isEmpty || round >= _maxToolRounds) return response;
    // [_history] already holds [turn] and the model's function-call content,
    // so the follow-up only carries the results. All results go into one
    // content so that parallel calls are answered together.
    final results = <Part>[for (final call in calls) _runFunctionCall(call)];
    return _generate([Content('function', results)], round + 1);
  }

  /// Applies a single model-requested [call] to the global theme notifiers
  /// and returns the result to report back to the model.
  ///
  /// Invalid or missing arguments produce an `Error` result instead of
  /// throwing, so the model can correct itself.
  FunctionResponse _runFunctionCall(FunctionCall call) {
    final args = call.args;
    switch (call.name) {
      case 'change_theme_color':
        final color = _parseHexColor(args['hex']);
        if (color == null) {
          return _result(
            call.name,
            'Error',
            'hex must be exactly 6 hexadecimal digits',
          );
        }
        themeColor.value = color;
        return _result(call.name, 'Success', 'theme color updated');
      case 'change_theme_mode':
        final mode = args['mode'];
        // Unrecognized or missing modes fall back to following the system.
        themeMode.value =
            switch (mode is String ? mode.trim().toLowerCase() : null) {
          'light' => ThemeMode.light,
          'dark' => ThemeMode.dark,
          _ => ThemeMode.system,
        };
        return _result(call.name, 'Success', 'theme mode updated');
      case 'change_text_scale_factor':
        final scale = args['scale'];
        if (scale is! num || !scale.isFinite || scale <= 0) {
          return _result(call.name, 'Error', 'scale must be a positive number');
        }
        textScaleFactor.value = scale.toDouble();
        return _result(call.name, 'Success', 'font scale updated');
      default:
        return _result(call.name, 'Error', 'unknown function ${call.name}');
    }
  }

  /// Parses a 6-digit RGB hex string, optionally prefixed with `#`, into an
  /// opaque [Color]; returns `null` for anything else.
  static Color? _parseHexColor(Object? value) {
    if (value is! String) return null;
    final match = RegExp(r'^#?([0-9a-fA-F]{6})$').firstMatch(value.trim());
    if (match == null) return null;
    return Color(0xFF000000 | int.parse(match.group(1)!, radix: 16));
  }

  /// Builds the `{type, message}` payload reported back for function [name].
  static FunctionResponse _result(String name, String type, String message) =>
      FunctionResponse(name, {'type': type, 'message': message});

  /// Appends a message from [sender] to the transcript.
  ///
  /// When [clear] is true, the model history and the transcript are reset
  /// first, so [value] becomes the only message.
  void addMessage(Sender sender, String value, {bool clear = false}) {
    if (clear) _history.clear();
    messages.value = [
      if (!clear) ...messages.value,
      (sender, value),
    ];
  }

  @override
  Widget build(BuildContext context) {
    return AnimatedBuilder(
      animation: messages,
      builder: (context, child) {
        final items = messages.value;
        return Scaffold(
          appBar: AppBar(
            title: Text(widget.title),
          ),
          body: items.isEmpty
              ? const Center(child: Text('Start changing the theme!'))
              : ListView.builder(
                  padding: const EdgeInsets.all(8),
                  // Reversed so the newest message sits next to the input.
                  reverse: true,
                  itemCount: items.length,
                  itemBuilder: (context, index) {
                    final (sender, message) = items[items.length - 1 - index];
                    return MessageWidget(
                      isFromUser: sender == Sender.user,
                      text: message,
                    );
                  },
                ),
          bottomNavigationBar: BottomAppBar(
            padding: const EdgeInsets.all(8),
            child: Row(
              children: [
                Expanded(
                  child: TextField(
                    controller: controller,
                    decoration: textFieldDecoration(
                      context,
                      'Change the theme color, font scale factor or brightness',
                    ),
                    // Flutter calls onEditingComplete and then onSubmitted for
                    // the same submit action, so only one of them sends.
                    // Using onEditingComplete also keeps the field focused.
                    onEditingComplete: sendMessage,
                  ),
                ),
                const SizedBox(width: 8),
                AnimatedBuilder(
                  animation: loading,
                  builder: (context, _) {
                    if (loading.value) {
                      return const CircularProgressIndicator();
                    }
                    return IconButton(
                      onPressed: sendMessage,
                      icon: const Icon(Icons.send),
                      tooltip: 'Send a message',
                    );
                  },
                ),
              ],
            ),
          ),
        );
      },
    );
  }
}
/// Author of a chat message.
enum Sender {
  /// Typed by the person using the app.
  user,

  /// Produced by the model, or an error notice shown in its place.
  system,
}
/// A chat bubble showing optional Markdown [text] and an optional [image].
///
/// User bubbles are right-aligned on the primary container color; all
/// others are left-aligned on the surface variant color.
class MessageWidget extends StatelessWidget {
  const MessageWidget({
    super.key,
    this.text,
    this.image,
    required this.isFromUser,
  });

  /// Image rendered below the text, if any.
  final Image? image;

  /// Markdown content of the bubble, if any.
  final String? text;

  /// Whether the message was written by the user.
  final bool isFromUser;

  @override
  Widget build(BuildContext context) {
    final scheme = Theme.of(context).colorScheme;
    final markdown = text;
    final picture = image;

    final bubble = Container(
      constraints: const BoxConstraints(maxWidth: 520),
      decoration: BoxDecoration(
        color: isFromUser ? scheme.primaryContainer : scheme.surfaceVariant,
        borderRadius: BorderRadius.circular(18),
      ),
      padding: const EdgeInsets.symmetric(vertical: 15, horizontal: 20),
      margin: const EdgeInsets.only(bottom: 8),
      child: Column(
        children: [
          if (markdown != null) MarkdownBody(data: markdown),
          if (picture != null) picture,
        ],
      ),
    );

    return Row(
      mainAxisAlignment:
          isFromUser ? MainAxisAlignment.end : MainAxisAlignment.start,
      children: [Flexible(child: bubble)],
    );
  }
}
/// Prompts the user for a Gemini API key and reports it via [onSubmitted].
class ApiKeyWidget extends StatefulWidget {
  const ApiKeyWidget({
    super.key,
    required this.onSubmitted,
    required this.title,
  });

  /// Title shown in the app bar.
  final String title;

  /// Called with the trimmed key when the user submits a non-empty value.
  final ValueChanged<String> onSubmitted;

  @override
  State<ApiKeyWidget> createState() => _ApiKeyWidgetState();
}

class _ApiKeyWidgetState extends State<ApiKeyWidget> {
  // Held in State (not the widget) so it survives rebuilds and is disposed.
  final _textController = TextEditingController();

  @override
  void dispose() {
    _textController.dispose();
    super.dispose();
  }

  /// Forwards [value] to [ApiKeyWidget.onSubmitted] after trimming, ignoring
  /// blank input so an empty key never reaches the model.
  void _submit(String value) {
    final trimmed = value.trim();
    if (trimmed.isEmpty) return;
    widget.onSubmitted(trimmed);
  }

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(
        title: Text(widget.title),
      ),
      body: Center(
        child: Padding(
          padding: const EdgeInsets.all(8.0),
          child: Column(
            mainAxisSize: MainAxisSize.min,
            children: [
              const Text(
                'To use the Gemini API, you\'ll need an API key. '
                'If you don\'t already have one, '
                'create a key in Google AI Studio.',
                textAlign: TextAlign.center,
              ),
              const SizedBox(height: 8),
              Link(
                uri: Uri.https('aistudio.google.com', '/app/apikey'),
                target: LinkTarget.blank,
                builder: (context, followLink) => TextButton(
                  onPressed: followLink,
                  child: const Text('Get an API Key'),
                ),
              ),
            ],
          ),
        ),
      ),
      bottomNavigationBar: BottomAppBar(
        padding: const EdgeInsets.all(8),
        child: Row(
          children: [
            Expanded(
              child: TextField(
                decoration: textFieldDecoration(context, 'Enter your API key'),
                controller: _textController,
                obscureText: true,
                onSubmitted: _submit,
              ),
            ),
            // Horizontal gap: inside a Row the spacer needs a width.
            const SizedBox(width: 8),
            TextButton(
              onPressed: () => _submit(_textController.text),
              child: const Text('Submit'),
            ),
          ],
        ),
      ),
    );
  }
}
/// Returns the rounded, outlined [InputDecoration] used by the app's text
/// fields, with [hintText] as the placeholder.
///
/// The idle and focused borders are identical and use the theme's
/// secondary color.
InputDecoration textFieldDecoration(BuildContext context, String hintText) {
  final outline = OutlineInputBorder(
    borderRadius: const BorderRadius.all(Radius.circular(14)),
    borderSide: BorderSide(color: Theme.of(context).colorScheme.secondary),
  );
  return InputDecoration(
    contentPadding: const EdgeInsets.all(15),
    hintText: hintText,
    border: outline,
    focusedBorder: outline,
  );
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment